You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

954 lines
26 KiB

3 weeks ago
package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strings"
"sync/atomic"
"testing"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/testutil"
)
func ExampleRegistry_cancelOnFirstError() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if err != nil {
// Discontinue pulling layers if there is an
// error instead of continuing to pull more
// data.
cancel()
}
},
})
var r Registry
if err := r.Pull(ctx, "model"); err != nil {
// panic for demo purposes
panic(err)
}
}
// TestManifestMarshalJSON checks that even the zero Manifest serializes a
// config object with a sha256 digest (the "empty" config).
func TestManifestMarshalJSON(t *testing.T) {
	data, err := json.Marshal(Manifest{})
	if err != nil {
		t.Fatal(err)
	}
	const wantConfig = `"config":{"digest":"sha256:`
	if bytes.Contains(data, []byte(wantConfig)) {
		return
	}
	t.Error("expected manifest to contain empty config")
	t.Fatalf("got:\n%s", string(data))
}
// errRoundTrip is returned by recordRoundTripper when the handler responds
// with status 499, letting tests simulate a transport-level failure.
var errRoundTrip = errors.New("forced roundtrip error")

// recordRoundTripper is an http.RoundTripper that serves each request
// in-process, by invoking the wrapped handler against an
// httptest.ResponseRecorder. No network connection is made.
type recordRoundTripper http.HandlerFunc

// RoundTrip calls the handler and returns the recorded response. If the
// handler wrote status 499, RoundTrip returns errRoundTrip and no response.
// Invoking a nil handler panics.
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	w := httptest.NewRecorder()
	rr(w, req)
	if w.Code == 499 {
		return nil, errRoundTrip
	}
	resp := w.Result()
	// httptest.ResponseRecorder does not populate Response.Request, but code
	// under test may read it, so set it on the response we return.
	resp.Request = req
	return resp, nil
}
// newClient constructs a Registry backed by a fresh on-disk cache that is
// pre-populated with these manifests:
//
//	empty: no data
//	zero: no layers
//	single: one layer with the contents "exists"
//	multiple: two layers with the contents "exists" and "present"
//	notfound: a layer that does not exist in the cache
//	null: one null layer (e.g. [null])
//	sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
//	invalid: a manifest that is not valid JSON
//
// It returns the Registry and its underlying cache.
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
	t.Helper()

	c, err := blob.Open(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}

	// mklayer imports data into the cache and returns a Layer describing it.
	mklayer := func(data string) *Layer {
		return &Layer{
			Digest: importBytes(t, c, data),
			Size:   int64(len(data)),
		}
	}

	r := &Registry{
		Cache: c,
		HTTPClient: &http.Client{
			Transport: recordRoundTripper(upstreamRegistry),
		},
	}

	// link stores manifest in the cache and links name to its digest.
	link := func(name string, manifest string) {
		t.Helper()
		n, err := r.parseName(name)
		if err != nil {
			t.Fatal(err)
		}
		d, err := c.Import(strings.NewReader(manifest), int64(len(manifest)))
		if err != nil {
			t.Fatal(err)
		}
		if err := c.Link(n.String(), d); err != nil {
			t.Fatal(err)
		}
	}

	// commit marshals a manifest with the given layers and links it as name.
	commit := func(name string, layers ...*Layer) {
		t.Helper()
		data, err := json.Marshal(&Manifest{Layers: layers})
		if err != nil {
			t.Fatal(err)
		}
		link(name, string(data))
	}

	link("empty", "")
	commit("zero")
	commit("single", mklayer("exists"))
	commit("multiple", mklayer("exists"), mklayer("present"))
	commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
	commit("null", nil)
	commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
	link("invalid", "!!!!!")
	return r, c
}
// okHandler is an upstream handler that answers every request with an empty
// 200 OK response.
func okHandler(w http.ResponseWriter, r *http.Request) {
	const status = http.StatusOK
	w.WriteHeader(status)
}
// checkErrCode reports a test error unless err wraps an *Error whose HTTP
// status and Code both equal the given values.
func checkErrCode(t *testing.T, err error, status int, code string) {
	t.Helper()
	var e *Error
	matched := errors.As(err, &e) && e.status == status && e.Code == code
	if !matched {
		t.Errorf("err = %v; want %v %v", err, status, code)
	}
}
// importBytes stores data in the cache c and returns its digest. It fails the
// test on error. It is marked as a helper so failures are reported at the
// caller's line.
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
	t.Helper()
	d, err := c.Import(strings.NewReader(data), int64(len(data)))
	if err != nil {
		t.Fatal(err)
	}
	return d
}
// withTraceUnexpected attaches, via WithTrace, a Trace whose Update callback
// panics. Use it in tests where no layer progress should ever be reported.
// The attached Trace is returned along with the new context.
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
	tr := new(Trace)
	tr.Update = func(*Layer, int64, error) { panic("unexpected") }
	return WithTrace(ctx, tr), tr
}
func TestPushZero(t *testing.T) {
	rc, _ := newClient(t, okHandler)
	// The "empty" manifest has no content at all, so pushing it must fail
	// with ErrManifestInvalid.
	if err := rc.Push(t.Context(), "empty", nil); !errors.Is(err, ErrManifestInvalid) {
		t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
	}
}
// TestPushSingle pushes a one-layer manifest to an upstream that accepts
// every request.
func TestPushSingle(t *testing.T) {
	rc, _ := newClient(t, okHandler)
	testutil.Check(t, rc.Push(t.Context(), "single", nil))
}
// TestPushMultiple pushes a two-layer manifest to an upstream that accepts
// every request.
func TestPushMultiple(t *testing.T) {
	rc, _ := newClient(t, okHandler)
	testutil.Check(t, rc.Push(t.Context(), "multiple", nil))
}
func TestPushNotFound(t *testing.T) {
	// The layer blob is missing locally, so Push must fail without talking
	// to the upstream. Any request is a test failure.
	failOnRequest := func(w http.ResponseWriter, r *http.Request) {
		t.Errorf("unexpected request: %v", r)
	}
	rc, _ := newClient(t, failOnRequest)
	err := rc.Push(t.Context(), "notfound", nil)
	if errors.Is(err, fs.ErrNotExist) {
		return
	}
	t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
func TestPushNullLayer(t *testing.T) {
	// A nil handler panics on any upstream request.
	rc, _ := newClient(t, nil)
	err := rc.Push(t.Context(), "null", nil)
	if ok := err != nil && strings.Contains(err.Error(), "invalid manifest"); !ok {
		t.Errorf("err = %v; want invalid manifest", err)
	}
}
func TestPushSizeMismatch(t *testing.T) {
	// Push must fail with a size mismatch without contacting the upstream
	// (nil handler) and without reporting progress (the trace panics).
	rc, _ := newClient(t, nil)
	ctx, _ := withTraceUnexpected(t.Context())
	err := rc.Push(ctx, "sizemismatch", nil)
	switch {
	case err == nil, !strings.Contains(err.Error(), "size mismatch"):
		t.Errorf("err = %v; want size mismatch", err)
	}
}
func TestPushInvalid(t *testing.T) {
	// The "invalid" manifest is not JSON; no upstream request is expected.
	rc, _ := newClient(t, nil)
	err := rc.Push(t.Context(), "invalid", nil)
	if err != nil && strings.Contains(err.Error(), "invalid manifest") {
		return
	}
	t.Errorf("err = %v; want invalid manifest", err)
}
// TestPushExistsAtRemote pushes the same model twice. The upstream hands out
// an upload URL only for the first upload request and answers later upload
// requests with 202 Accepted. Both pushes must succeed, and the first push
// must not report any per-layer errors through the trace.
func TestPushExistsAtRemote(t *testing.T) {
	var pushed bool
	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/uploads/") {
			if !pushed {
				// First push. Return an uploadURL.
				pushed = true
				w.Header().Set("Location", "http://blob.store/blobs/123")
				return
			}
			w.WriteHeader(http.StatusAccepted)
			return
		}
		io.Copy(io.Discard, r.Body)
		w.WriteHeader(http.StatusOK)
	})
	rc.MaxStreams = 1 // prevent concurrent uploads

	var errs []error
	ctx := WithTrace(t.Context(), &Trace{
		Update: func(_ *Layer, n int64, err error) {
			// uploading one at a time so no need to lock
			errs = append(errs, err)
		},
	})

	check := testutil.Checker(t)
	err := rc.Push(ctx, "single", nil)
	check(err)

	// errors.Join returns nil only when every element is nil.
	if errors.Join(errs...) != nil {
		t.Errorf("errs = %v; want all nil", errs)
	}

	err = rc.Push(ctx, "single", nil)
	check(err)
}
func TestPushRemoteError(t *testing.T) {
	// Every blob endpoint fails with a 500 and a JSON error body. All other
	// requests get an implicit 200.
	handler := func(w http.ResponseWriter, r *http.Request) {
		if !strings.Contains(r.URL.Path, "/blobs/") {
			return
		}
		w.WriteHeader(500)
		io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
	}
	rc, _ := newClient(t, handler)
	checkErrCode(t, rc.Push(t.Context(), "single", nil), 500, "blob_error")
}
func TestPushLocationError(t *testing.T) {
	// Every response is 202 Accepted with a malformed Location header.
	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Location", ":///x")
		w.WriteHeader(http.StatusAccepted)
	})
	const wantContains = "invalid upload URL"
	err := rc.Push(t.Context(), "single", nil)
	if err != nil && strings.Contains(err.Error(), wantContains) {
		return
	}
	t.Errorf("err = %v; want to contain %v", err, wantContains)
}
func TestPushUploadRoundtripError(t *testing.T) {
	// Requests to any host other than blob.store get a Location pointing at
	// blob.store. The upload to blob.store then fails at the transport level.
	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
		if r.Host != "blob.store" {
			w.Header().Set("Location", "http://blob.store/blobs/123")
			return
		}
		w.WriteHeader(499) // force RoundTrip error on upload
	})
	if err := rc.Push(t.Context(), "single", nil); !errors.Is(err, errRoundTrip) {
		t.Errorf("got = %v; want %v", err, errRoundTrip)
	}
}
func TestPushUploadFileOpenError(t *testing.T) {
	rc, c := newClient(t, okHandler)
	// Remove the layer's file just before it is opened for upload, but after
	// the initial Stat that happens before the upload starts.
	removeBlob := func(l *Layer, _ int64, _ error) {
		// Best-effort: repeated calls may find the file already removed.
		os.Remove(c.GetFile(l.Digest))
	}
	ctx := WithTrace(t.Context(), &Trace{Update: removeBlob})
	err := rc.Push(ctx, "single", nil)
	if !errors.Is(err, fs.ErrNotExist) {
		t.Errorf("got = %v; want fs.ErrNotExist", err)
	}
}
func TestPushCommitRoundtripError(t *testing.T) {
	// "zero" has no layers, so no blob endpoint may be hit. Every other
	// request (the manifest commit) fails at the transport level.
	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
		if !strings.Contains(r.URL.Path, "/blobs/") {
			w.WriteHeader(499) // force RoundTrip error
			return
		}
		panic("unexpected")
	})
	if err := rc.Push(t.Context(), "zero", nil); !errors.Is(err, errRoundTrip) {
		t.Errorf("err = %v; want %v", err, errRoundTrip)
	}
}
func TestRegistryPullInvalidName(t *testing.T) {
	rc, _ := newRegistryClient(t, nil)
	const badName = "://"
	if err := rc.Pull(t.Context(), badName); !errors.Is(err, ErrNameInvalid) {
		t.Errorf("err = %v; want %v", err, ErrNameInvalid)
	}
}
// TestRegistryPullInvalidManifest checks that Pull reports ErrManifestInvalid
// for manifest bodies that are empty, null, not JSON, or have no layers.
func TestRegistryPullInvalidManifest(t *testing.T) {
	cases := []string{
		"",
		"null",
		"!!!",
		`{"layers":[]}`,
	}
	for _, resp := range cases {
		// Each body runs as its own subtest, and the message quotes the body,
		// so a failure identifies which manifest caused it.
		t.Run(resp, func(t *testing.T) {
			rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
				io.WriteString(w, resp)
			})
			err := rc.Pull(t.Context(), "http://example.com/a/b")
			if !errors.Is(err, ErrManifestInvalid) {
				t.Errorf("manifest %q: err = %v; want invalid manifest", resp, err)
			}
		})
	}
}
// TestRegistryResolveByDigest checks that resolving "name@digest" fetches the
// manifest from the blobs endpoint for that digest. Any other path, including
// the manifests endpoint, fails with a transport error.
func TestRegistryResolveByDigest(t *testing.T) {
	check := testutil.Checker(t)
	exists := blob.DigestFromBytes("exists")
	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
			w.WriteHeader(499) // should not hit manifest endpoint
			return
		}
		fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
	})
	_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
	check(err)
}
func TestInsecureSkipVerify(t *testing.T) {
	exists := blob.DigestFromBytes("exists")
	srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
	}))
	defer srv.Close()

	const name = "library/insecure"
	addr := srv.Listener.Addr()
	var rc Registry

	// A plain https:// name must fail: the test server's certificate is not
	// trusted by the default client.
	_, err := rc.Resolve(t.Context(), fmt.Sprintf("https://%s/%s", addr, name))
	if err == nil || !strings.Contains(err.Error(), "failed to verify") {
		t.Errorf("err = %v; want cert verification failure", err)
	}

	// https+insecure:// skips verification, so the same server succeeds.
	_, err = rc.Resolve(t.Context(), fmt.Sprintf("https+insecure://%s/%s", addr, name))
	testutil.Check(t, err)
}
// TestErrorUnmarshal checks Error's JSON decoding of two forms:
//   - the "errors" list form, where only the first entry is kept;
//   - the single "error"/"code" form.
//
// It also checks that bodies without a usable error are rejected.
func TestErrorUnmarshal(t *testing.T) {
	cases := []struct {
		name    string
		data    string
		want    *Error
		wantErr bool
	}{
		{
			name:    "errors empty",
			data:    `{"errors":[]}`,
			wantErr: true,
		},
		{
			name: "errors single",
			data: `{"errors":[{"code":"blob_unknown"}]}`,
			want: &Error{Code: "blob_unknown", Message: ""},
		},
		{
			name: "errors multiple",
			data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
			want: &Error{Code: "blob_unknown", Message: ""},
		},
		{
			name:    "error empty",
			data:    `{"error":""}`,
			wantErr: true,
		},
		{
			name:    "error very empty",
			data:    `{}`,
			wantErr: true,
		},
		{
			name: "error message",
			data: `{"error":"message", "code":"code"}`,
			want: &Error{Code: "code", Message: "message"},
		},
		{
			name:    "invalid value",
			data:    `{"error": 1}`,
			wantErr: true,
		},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			var got Error
			err := json.Unmarshal([]byte(tt.data), &got)
			if err != nil {
				if tt.wantErr {
					return
				}
				t.Errorf("Unmarshal() error = %v", err)
				// fallthrough and check got
			} else if tt.wantErr {
				// Previously an unexpected success fell through to the
				// zero-value comparison below and passed silently.
				t.Fatalf("Unmarshal() error = nil; want error (got %v)", got)
			}
			want := tt.want
			if want == nil {
				want = &Error{}
			}
			if !reflect.DeepEqual(got, *want) {
				t.Errorf("got = %v; want %v", got, *want)
			}
		})
	}
}
// TestParseNameExtendedErrors tests that parseNameExtended returns error
// messages with enough detail for users to debug naming issues they may
// encounter. Before this test existed, the error messages were not very
// helpful and every problem was reported with the same message.
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
//
// NOTE(review): cases is currently empty, so this test checks nothing. Each
// case needs a sentinel err (matched with errors.Is) and a substring want
// that the message must contain.
func TestParseNameExtendedErrors(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want string
	}{}
	var r Registry
	for _, tt := range cases {
		_, _, _, err := r.parseNameExtended(tt.name)
		if !errors.Is(err, tt.err) {
			t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
		}
		if err != nil && !strings.Contains(err.Error(), tt.want) {
			t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
		}
	}
}
func TestParseNameExtended(t *testing.T) {
cases := []struct {
in string
scheme string
name string
digest string
err string
}{
{in: "http://m", scheme: "http", name: "m"},
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
{in: "http+insecure://m", err: "unsupported scheme"},
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
{in: "", err: "invalid or missing name"},
{in: "m", scheme: "https", name: "m"},
{in: "://", err: "invalid or missing name"},
{in: "@sha256:deadbeef", err: "invalid digest"},
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
var r Registry
scheme, n, digest, err := r.parseNameExtended(tt.in)
if err != nil {
if tt.err == "" {
t.Errorf("err = %v; want nil", err)
} else if !strings.Contains(err.Error(), tt.err) {
t.Errorf("err = %v; want %q", err, tt.err)
}
} else if tt.err != "" {
t.Errorf("err = nil; want %q", tt.err)
}
if err == nil && !n.IsFullyQualified() {
t.Errorf("name = %q; want fully qualified", n)
}
if scheme != tt.scheme {
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
}
// smoke-test name is superset of tt.name
if !strings.Contains(n.String(), tt.name) {
t.Errorf("name = %q; want %q", n, tt.name)
}
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
if digest.String() != tt.digest {
t.Errorf("digest = %q; want %q", digest, tt.digest)
}
})
}
}
func TestUnlink(t *testing.T) {
	t.Run("found by name", func(t *testing.T) {
		check := testutil.Checker(t)
		rc, _ := newRegistryClient(t, nil)

		// Store a "{}" blob and link the fully qualified name to it.
		d := blob.DigestFromBytes("{}")
		check(blob.PutBytes(rc.Cache, d, "{}"))
		check(rc.Cache.Link("registry.ollama.ai/library/single:latest", d))

		// The short name resolves while linked...
		_, err := rc.ResolveLocal("single")
		check(err)

		// ...and stops resolving once unlinked.
		_, err = rc.Unlink("single")
		check(err)
		if _, err := rc.ResolveLocal("single"); !errors.Is(err, fs.ErrNotExist) {
			t.Errorf("err = %v; want fs.ErrNotExist", err)
		}
	})

	t.Run("not found by name", func(t *testing.T) {
		rc, _ := newRegistryClient(t, nil)
		ok, err := rc.Unlink("manifestNotFound")
		switch {
		case err != nil:
			t.Fatal(err)
		case ok:
			t.Error("expected not found")
		}
	})
}
// Many of the tests from here on in this file are based on a single blob,
// "abc", identified by its sha256 checksum, which is:
//
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
//
// Using the literal value instead of a constant with fmt.Xprintf calls proved
// to be the most readable and maintainable approach. The sum is consistently
// used in the tests and unique so searches do not yield false positives.
// checkRequest reports a test error if req's URL path or method differ from
// the expected values. The path is checked before the method, and both
// mismatches are reported.
func checkRequest(t *testing.T, req *http.Request, method, path string) {
	t.Helper()
	gotPath, gotMethod := req.URL.Path, req.Method
	if gotPath != path {
		t.Errorf("URL = %q, want %q", gotPath, path)
	}
	if gotMethod != method {
		t.Errorf("Method = %q, want %q", gotMethod, method)
	}
}
// newRegistryClient starts an httptest server running upstream and returns a
// Registry backed by a fresh on-disk cache. The Registry's transport connects
// to that server whatever host the request names, so tests can use any
// registry host (e.g. "http://o.com/...") in model names. The server is closed
// when the test finishes.
//
// The returned context carries a Trace that logs every layer update with
// t.Log.
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
	t.Helper()
	s := httptest.NewServer(upstream)
	t.Cleanup(s.Close)

	cache, err := blob.Open(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}

	ctx := WithTrace(t.Context(), &Trace{
		Update: func(l *Layer, n int64, err error) {
			t.Log("trace:", l.Digest.Short(), n, err)
		},
	})

	// Ignore the requested address and always dial the test server.
	// DialContext replaces the deprecated Transport.Dial and honors request
	// cancellation while connecting.
	var dialer net.Dialer
	rc := &Registry{
		Cache: cache,
		HTTPClient: &http.Client{Transport: &http.Transport{
			DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, network, s.Listener.Addr().String())
			},
		}},
	}
	return rc, ctx
}
// TestPullChunked forces chunked downloading (ChunkingThreshold = 1) of the
// 3-byte "abc" blob and checks the exact request sequence:
//  1. the manifest;
//  2. the chunksums listing;
//  3-4. one ranged blob request per chunk ("ab" as bytes 0-1, "c" as
//     bytes 2-2).
//
// That is four requests in total. It then checks that the pulled model
// resolves in the local cache as o.com/library/abc:latest.
func TestPullChunked(t *testing.T) {
	var steps atomic.Int64
	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		switch steps.Add(1) {
		case 1:
			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
		case 2:
			checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			// The ranged requests in steps 3-4 are expected at this
			// Content-Location's path. The host is irrelevant because every
			// connection goes to the test server.
			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			// One line per chunk: "<chunk digest> <start>-<end>", with
			// inclusive byte offsets.
			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
		case 3, 4:
			checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			// The two chunk requests are accepted in either order.
			switch rng := r.Header.Get("Range"); rng {
			case "bytes=0-1":
				io.WriteString(w, "ab")
			case "bytes=2-2":
				t.Logf("writing c")
				io.WriteString(w, "c")
			default:
				t.Errorf("unexpected range %q", rng)
			}
		default:
			t.Errorf("unexpected steps %d: %v", steps.Load(), r)
			http.Error(w, "unexpected steps", http.StatusInternalServerError)
		}
	})
	c.ChunkingThreshold = 1 // force chunking
	err := c.Pull(ctx, "http://o.com/library/abc")
	testutil.Check(t, err)
	_, err = c.Cache.Resolve("o.com/library/abc:latest")
	testutil.Check(t, err)
	// Exactly four requests: manifest, chunksums, and two chunks.
	if g := steps.Load(); g != 4 {
		t.Fatalf("got %d steps, want 4", g)
	}
}
// TestPullCached checks that when the "abc" blob is already in the local
// cache, Pull fetches only the manifest. Any other request fails
// checkRequest.
func TestPullCached(t *testing.T) {
	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
		io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
	})
	check := testutil.Checker(t)

	// Preemptively cache the blob.
	d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
	check(err)
	check(blob.PutBytes(c.Cache, d, []byte("abc")))

	// The manifest alone should be enough to resolve the cached blob.
	check(c.Pull(ctx, "http://o.com/library/abc"))
}
// TestPullManifestError checks that a 404 MANIFEST_UNKNOWN response from the
// manifest endpoint is reported as ErrModelNotFound.
func TestPullManifestError(t *testing.T) {
	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
		w.WriteHeader(http.StatusNotFound)
		io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
	})
	err := c.Pull(ctx, "http://o.com/library/abc")
	if err == nil {
		t.Fatalf("expected error")
	}
	if !errors.Is(err, ErrModelNotFound) {
		// Report the actual error. The previous message printed an unused,
		// always-nil *Error instead.
		t.Fatalf("err = %v, want %v", err, ErrModelNotFound)
	}
}
// TestPullLayerError checks that a manifest body which is not valid JSON
// surfaces a *json.SyntaxError to the caller.
func TestPullLayerError(t *testing.T) {
	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
		io.WriteString(w, `!`)
	})
	err := c.Pull(ctx, "http://o.com/library/abc")
	if err == nil {
		t.Fatalf("expected error")
	}
	var syntaxErr *json.SyntaxError
	if errors.As(err, &syntaxErr) {
		return
	}
	t.Fatalf("err = %T, want %T", err, syntaxErr)
}
// TestPullLayerChecksumError forces chunked download of "abc" with
// MaxStreams = 1. The first chunk request (step 3) fails with 404
// BLOB_UNKNOWN, and the second (step 4) serves "c". Pull must return an
// *Error with code BLOB_UNKNOWN, and the trace's reported byte counts must
// sum to 1.
func TestPullLayerChecksumError(t *testing.T) {
	var step atomic.Int64
	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		switch step.Add(1) {
		case 1:
			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
		case 2:
			checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
		case 3:
			w.WriteHeader(http.StatusNotFound)
			io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
		case 4:
			io.WriteString(w, "c")
		default:
			t.Errorf("unexpected steps %d: %v", step.Load(), r)
			http.Error(w, "unexpected steps", http.StatusInternalServerError)
		}
	})
	c.MaxStreams = 1
	c.ChunkingThreshold = 1 // force chunking

	var written atomic.Int64
	ctx := WithTrace(t.Context(), &Trace{
		Update: func(l *Layer, n int64, err error) {
			t.Log("trace:", l.Digest.Short(), n, err)
			written.Add(n)
		},
	})

	err := c.Pull(ctx, "http://o.com/library/abc")
	var got *Error
	if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
		// State the expectation explicitly. The previous message printed the
		// errors.As target, which is nil or the mismatched *Error.
		t.Fatalf("err = %v, want *Error with code %q", err, "BLOB_UNKNOWN")
	}
	if g := written.Load(); g != 1 {
		t.Fatalf("wrote %d bytes, want 1", g)
	}
}
// TestPullChunksumStreamError forces chunked download of "abc". The
// chunksums response contains one valid line followed by a malformed one
// ("sha256:!", with no trailing newline). Pull must return ErrIncomplete.
// Step 3 serves the valid "ab" chunk if the client requests it.
func TestPullChunksumStreamError(t *testing.T) {
	var step atomic.Int64
	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		switch step.Add(1) {
		case 1:
			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
		case 2:
			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			// Write one valid chunksum and one invalid chunksum
			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
			fmt.Fprint(w, "sha256:!")                               // invalid
		case 3:
			io.WriteString(w, "ab")
		default:
			t.Errorf("unexpected steps %d: %v", step.Load(), r)
			http.Error(w, "unexpected steps", http.StatusInternalServerError)
		}
	})
	c.ChunkingThreshold = 1 // force chunking
	got := c.Pull(ctx, "http://o.com/library/abc")
	if !errors.Is(got, ErrIncomplete) {
		t.Fatalf("err = %v, want %v", got, ErrIncomplete)
	}
}
// flushAfterWriter wraps a writer that must also be an http.Flusher, and
// flushes it after every Write. A streaming handler uses it so each chunk
// reaches the client as soon as it is written.
type flushAfterWriter struct {
	w io.Writer
}

// Write writes p to the underlying writer and then flushes it, returning the
// underlying Write's results. It panics if the underlying writer is not an
// http.Flusher.
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
	n, err = f.w.Write(p)
	flusher := f.w.(http.Flusher) // panic if not a flusher
	flusher.Flush()
	return n, err
}
// TestPullChunksumStreaming checks that each chunk is downloaded as soon as
// its chunksum line arrives, not after the whole chunksums response has been
// read. The chunksums handler (step 2) copies from a pipe and flushes after
// every write, so the test controls exactly when each line reaches the
// client.
func TestPullChunksumStreaming(t *testing.T) {
	csr, csw := io.Pipe()
	defer csw.Close()
	var step atomic.Int64
	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		switch step.Add(1) {
		case 1:
			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
		case 2:
			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
			// Blocks until the test closes csw, so this response stays open
			// while chunks are being downloaded.
			_, err := io.Copy(fw, csr)
			if err != nil {
				t.Errorf("copy: %v", err)
			}
		case 3:
			io.WriteString(w, "ab")
		case 4:
			io.WriteString(w, "c")
		default:
			t.Errorf("unexpected steps %d: %v", step.Load(), r)
			http.Error(w, "unexpected steps", http.StatusInternalServerError)
		}
	})
	c.ChunkingThreshold = 1 // force chunking
	// update receives every positive progress value reported by the trace.
	// Judging by the expected values below (2, then 3), n appears to be the
	// layer's cumulative byte count.
	update := make(chan int64, 1)
	ctx := WithTrace(t.Context(), &Trace{
		Update: func(l *Layer, n int64, err error) {
			t.Log("trace:", l.Digest.Short(), n, err)
			if n > 0 {
				update <- n
			}
		},
	})
	// Run Pull in the background so the test can feed it chunksums one line
	// at a time.
	errc := make(chan error, 1)
	go func() {
		errc <- c.Pull(ctx, "http://o.com/library/abc")
	}()
	// Send first chunksum and ensure it kicks off work immediately
	fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
	if g := <-update; g != 2 {
		t.Fatalf("got %d, want 2", g)
	}
	// now send the second chunksum and ensure it kicks off work immediately
	fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
	if g := <-update; g != 3 {
		t.Fatalf("got %d, want 3", g)
	}
	// Closing the pipe makes io.Copy return, ending the chunksums response
	// so Pull can finish.
	csw.Close()
	testutil.Check(t, <-errc)
}
// TestPullChunksumsCached checks that chunks downloaded before a pull was
// canceled are reused by the next pull.
//
// The first pull is canceled on its first progress update, so the model must
// not resolve afterwards. The second pull must complete; its trace reports 5
// bytes in total, 2 of which ("ab") are flagged ErrCached.
func TestPullChunksumsCached(t *testing.T) {
	var step atomic.Int64
	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		switch step.Add(1) {
		case 1:
			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
		case 2:
			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
		case 3, 4:
			switch rng := r.Header.Get("Range"); rng {
			case "bytes=0-1":
				io.WriteString(w, "ab")
			case "bytes=2-2":
				io.WriteString(w, "c")
			default:
				t.Errorf("unexpected range %q", rng)
			}
		default:
			t.Errorf("unexpected steps %d: %v", step.Load(), r)
			http.Error(w, "unexpected steps", http.StatusInternalServerError)
		}
	})
	c.MaxStreams = 1         // force serial processing of chunksums
	c.ChunkingThreshold = 1 // force chunking

	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()

	// Cancel the pull after the first chunksum is processed, but before
	// the second chunksum is processed (which is waiting because
	// MaxStreams=1). This should cause the second chunksum to error out
	// leaving the blob incomplete.
	ctx = WithTrace(ctx, &Trace{
		Update: func(l *Layer, n int64, err error) {
			if n > 0 {
				cancel()
			}
		},
	})
	err := c.Pull(ctx, "http://o.com/library/abc")
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("err = %v, want %v", err, context.Canceled)
	}
	_, err = c.Cache.Resolve("o.com/library/abc:latest")
	if !errors.Is(err, fs.ErrNotExist) {
		t.Fatalf("err = %v, want %v", err, fs.ErrNotExist)
	}

	// Reset state and pull again to ensure the blob chunks that should
	// have been cached are, and the remaining chunk was downloaded, making
	// the blob complete.
	step.Store(0)
	var written atomic.Int64
	var cached atomic.Int64
	ctx = WithTrace(t.Context(), &Trace{
		Update: func(l *Layer, n int64, err error) {
			t.Log("trace:", l.Digest.Short(), n, err)
			if errors.Is(err, ErrCached) {
				cached.Add(n)
			}
			written.Add(n)
		},
	})

	check := testutil.Checker(t)
	err = c.Pull(ctx, "http://o.com/library/abc")
	check(err)
	_, err = c.Cache.Resolve("o.com/library/abc:latest")
	check(err)
	if g := written.Load(); g != 5 {
		t.Fatalf("wrote %d bytes, want 5", g)
	}
	if g := cached.Load(); g != 2 { // "ab" should have been cached
		t.Fatalf("cached %d bytes, want 2", g)
	}
}