You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
954 lines
26 KiB
954 lines
26 KiB
package ollama
|
|
|
|
import (
|
|
"bytes"
|
|
"cmp"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
"github.com/ollama/ollama/server/internal/testutil"
|
|
)
|
|
|
|
func ExampleRegistry_cancelOnFirstError() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
ctx = WithTrace(ctx, &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
if err != nil {
|
|
// Discontinue pulling layers if there is an
|
|
// error instead of continuing to pull more
|
|
// data.
|
|
cancel()
|
|
}
|
|
},
|
|
})
|
|
|
|
var r Registry
|
|
if err := r.Pull(ctx, "model"); err != nil {
|
|
// panic for demo purposes
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func TestManifestMarshalJSON(t *testing.T) {
|
|
// All manifests should contain an "empty" config object.
|
|
var m Manifest
|
|
data, err := json.Marshal(m)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
|
|
t.Error("expected manifest to contain empty config")
|
|
t.Fatalf("got:\n%s", string(data))
|
|
}
|
|
}
|
|
|
|
var errRoundTrip = errors.New("forced roundtrip error")
|
|
|
|
type recordRoundTripper http.HandlerFunc
|
|
|
|
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
w := httptest.NewRecorder()
|
|
rr(w, req)
|
|
if w.Code == 499 {
|
|
return nil, errRoundTrip
|
|
}
|
|
resp := w.Result()
|
|
// For some reason, Response.Request is not set by httptest.NewRecorder, so we
|
|
// set it manually.
|
|
resp.Request = req
|
|
return w.Result(), nil
|
|
}
|
|
|
|
// newClient constructs a cache with predefined manifests for testing. The manifests are:
|
|
//
|
|
// empty: no data
|
|
// zero: no layers
|
|
// single: one layer with the contents "exists"
|
|
// multiple: two layers with the contents "exists" and "here"
|
|
// notfound: a layer that does not exist in the cache
|
|
// null: one null layer (e.g. [null])
|
|
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
|
|
// invalid: a layer with invalid JSON data
|
|
//
|
|
// Tests that want to ensure the client does not communicate with the upstream
|
|
// registry should pass a nil handler, which will cause a panic if
|
|
// communication is attempted.
|
|
//
|
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
|
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|
t.Helper()
|
|
|
|
c, err := blob.Open(t.TempDir())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
mklayer := func(data string) *Layer {
|
|
return &Layer{
|
|
Digest: importBytes(t, c, data),
|
|
Size: int64(len(data)),
|
|
}
|
|
}
|
|
|
|
r := &Registry{
|
|
Cache: c,
|
|
HTTPClient: &http.Client{
|
|
Transport: recordRoundTripper(upstreamRegistry),
|
|
},
|
|
}
|
|
|
|
link := func(name string, manifest string) {
|
|
n, err := r.parseName(name)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if err := c.Link(n.String(), d); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
commit := func(name string, layers ...*Layer) {
|
|
t.Helper()
|
|
data, err := json.Marshal(&Manifest{Layers: layers})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
link(name, string(data))
|
|
}
|
|
|
|
link("empty", "")
|
|
commit("zero")
|
|
commit("single", mklayer("exists"))
|
|
commit("multiple", mklayer("exists"), mklayer("present"))
|
|
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
|
|
commit("null", nil)
|
|
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
|
link("invalid", "!!!!!")
|
|
|
|
return r, c
|
|
}
|
|
|
|
func okHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func checkErrCode(t *testing.T, err error, status int, code string) {
|
|
t.Helper()
|
|
var e *Error
|
|
if !errors.As(err, &e) || e.status != status || e.Code != code {
|
|
t.Errorf("err = %v; want %v %v", err, status, code)
|
|
}
|
|
}
|
|
|
|
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
|
|
d, err := c.Import(strings.NewReader(data), int64(len(data)))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return d
|
|
}
|
|
|
|
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
|
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
|
return WithTrace(ctx, t), t
|
|
}
|
|
|
|
func TestPushZero(t *testing.T) {
|
|
rc, _ := newClient(t, okHandler)
|
|
err := rc.Push(t.Context(), "empty", nil)
|
|
if !errors.Is(err, ErrManifestInvalid) {
|
|
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
|
}
|
|
}
|
|
|
|
func TestPushSingle(t *testing.T) {
|
|
rc, _ := newClient(t, okHandler)
|
|
err := rc.Push(t.Context(), "single", nil)
|
|
testutil.Check(t, err)
|
|
}
|
|
|
|
func TestPushMultiple(t *testing.T) {
|
|
rc, _ := newClient(t, okHandler)
|
|
err := rc.Push(t.Context(), "multiple", nil)
|
|
testutil.Check(t, err)
|
|
}
|
|
|
|
func TestPushNotFound(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
t.Errorf("unexpected request: %v", r)
|
|
})
|
|
err := rc.Push(t.Context(), "notfound", nil)
|
|
if !errors.Is(err, fs.ErrNotExist) {
|
|
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
|
}
|
|
}
|
|
|
|
func TestPushNullLayer(t *testing.T) {
|
|
rc, _ := newClient(t, nil)
|
|
err := rc.Push(t.Context(), "null", nil)
|
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
|
t.Errorf("err = %v; want invalid manifest", err)
|
|
}
|
|
}
|
|
|
|
func TestPushSizeMismatch(t *testing.T) {
|
|
rc, _ := newClient(t, nil)
|
|
ctx, _ := withTraceUnexpected(t.Context())
|
|
got := rc.Push(ctx, "sizemismatch", nil)
|
|
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
|
t.Errorf("err = %v; want size mismatch", got)
|
|
}
|
|
}
|
|
|
|
func TestPushInvalid(t *testing.T) {
|
|
rc, _ := newClient(t, nil)
|
|
err := rc.Push(t.Context(), "invalid", nil)
|
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
|
t.Errorf("err = %v; want invalid manifest", err)
|
|
}
|
|
}
|
|
|
|
func TestPushExistsAtRemote(t *testing.T) {
|
|
var pushed bool
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "/uploads/") {
|
|
if !pushed {
|
|
// First push. Return an uploadURL.
|
|
pushed = true
|
|
w.Header().Set("Location", "http://blob.store/blobs/123")
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
rc.MaxStreams = 1 // prevent concurrent uploads
|
|
|
|
var errs []error
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(_ *Layer, n int64, err error) {
|
|
// uploading one at a time so no need to lock
|
|
errs = append(errs, err)
|
|
},
|
|
})
|
|
|
|
check := testutil.Checker(t)
|
|
|
|
err := rc.Push(ctx, "single", nil)
|
|
check(err)
|
|
|
|
if !errors.Is(errors.Join(errs...), nil) {
|
|
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
|
}
|
|
|
|
err = rc.Push(ctx, "single", nil)
|
|
check(err)
|
|
}
|
|
|
|
func TestPushRemoteError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
|
w.WriteHeader(500)
|
|
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
|
return
|
|
}
|
|
})
|
|
got := rc.Push(t.Context(), "single", nil)
|
|
checkErrCode(t, got, 500, "blob_error")
|
|
}
|
|
|
|
func TestPushLocationError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Location", ":///x")
|
|
w.WriteHeader(http.StatusAccepted)
|
|
})
|
|
got := rc.Push(t.Context(), "single", nil)
|
|
wantContains := "invalid upload URL"
|
|
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
|
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
|
}
|
|
}
|
|
|
|
func TestPushUploadRoundtripError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Host == "blob.store" {
|
|
w.WriteHeader(499) // force RoundTrip error on upload
|
|
return
|
|
}
|
|
w.Header().Set("Location", "http://blob.store/blobs/123")
|
|
})
|
|
got := rc.Push(t.Context(), "single", nil)
|
|
if !errors.Is(got, errRoundTrip) {
|
|
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
|
}
|
|
}
|
|
|
|
func TestPushUploadFileOpenError(t *testing.T) {
|
|
rc, c := newClient(t, okHandler)
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, _ int64, err error) {
|
|
// Remove the file just before it is opened for upload,
|
|
// but after the initial Stat that happens before the
|
|
// upload starts
|
|
os.Remove(c.GetFile(l.Digest))
|
|
},
|
|
})
|
|
got := rc.Push(ctx, "single", nil)
|
|
if !errors.Is(got, fs.ErrNotExist) {
|
|
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
|
}
|
|
}
|
|
|
|
func TestPushCommitRoundtripError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
|
panic("unexpected")
|
|
}
|
|
w.WriteHeader(499) // force RoundTrip error
|
|
})
|
|
err := rc.Push(t.Context(), "zero", nil)
|
|
if !errors.Is(err, errRoundTrip) {
|
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
|
}
|
|
}
|
|
|
|
func TestRegistryPullInvalidName(t *testing.T) {
|
|
rc, _ := newRegistryClient(t, nil)
|
|
err := rc.Pull(t.Context(), "://")
|
|
if !errors.Is(err, ErrNameInvalid) {
|
|
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
|
}
|
|
}
|
|
|
|
func TestRegistryPullInvalidManifest(t *testing.T) {
|
|
cases := []string{
|
|
"",
|
|
"null",
|
|
"!!!",
|
|
`{"layers":[]}`,
|
|
}
|
|
|
|
for _, resp := range cases {
|
|
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
io.WriteString(w, resp)
|
|
})
|
|
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
|
if !errors.Is(err, ErrManifestInvalid) {
|
|
t.Errorf("err = %v; want invalid manifest", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRegistryResolveByDigest(t *testing.T) {
|
|
check := testutil.Checker(t)
|
|
|
|
exists := blob.DigestFromBytes("exists")
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
|
|
w.WriteHeader(499) // should not hit manifest endpoint
|
|
}
|
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
|
})
|
|
|
|
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
|
|
check(err)
|
|
}
|
|
|
|
func TestInsecureSkipVerify(t *testing.T) {
|
|
exists := blob.DigestFromBytes("exists")
|
|
|
|
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
|
}))
|
|
defer s.Close()
|
|
|
|
const name = "library/insecure"
|
|
|
|
var rc Registry
|
|
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
|
_, err := rc.Resolve(t.Context(), url)
|
|
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
|
|
t.Errorf("err = %v; want cert verification failure", err)
|
|
}
|
|
|
|
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
|
|
_, err = rc.Resolve(t.Context(), url)
|
|
testutil.Check(t, err)
|
|
}
|
|
|
|
func TestErrorUnmarshal(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
data string
|
|
want *Error
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "errors empty",
|
|
data: `{"errors":[]}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "errors empty",
|
|
data: `{"errors":[]}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "errors single",
|
|
data: `{"errors":[{"code":"blob_unknown"}]}`,
|
|
want: &Error{Code: "blob_unknown", Message: ""},
|
|
},
|
|
{
|
|
name: "errors multiple",
|
|
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
|
|
want: &Error{Code: "blob_unknown", Message: ""},
|
|
},
|
|
{
|
|
name: "error empty",
|
|
data: `{"error":""}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "error very empty",
|
|
data: `{}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "error message",
|
|
data: `{"error":"message", "code":"code"}`,
|
|
want: &Error{Code: "code", Message: "message"},
|
|
},
|
|
{
|
|
name: "invalid value",
|
|
data: `{"error": 1}`,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for _, tt := range cases {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var got Error
|
|
err := json.Unmarshal([]byte(tt.data), &got)
|
|
if err != nil {
|
|
if tt.wantErr {
|
|
return
|
|
}
|
|
t.Errorf("Unmarshal() error = %v", err)
|
|
// fallthrough and check got
|
|
}
|
|
if tt.want == nil {
|
|
tt.want = &Error{}
|
|
}
|
|
if !reflect.DeepEqual(got, *tt.want) {
|
|
t.Errorf("got = %v; want %v", got, *tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestParseNameErrors tests that parseName returns errors messages with enough
|
|
// detail for users to debug naming issues they may encounter. Previous to this
|
|
// test, the error messages were not very helpful and each problem was reported
|
|
// as the same message.
|
|
//
|
|
// It is only for testing error messages, not that all invalids and valids are
|
|
// covered. Those are in other tests for names.Name and blob.Digest.
|
|
func TestParseNameExtendedErrors(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
err error
|
|
want string
|
|
}{}
|
|
|
|
var r Registry
|
|
for _, tt := range cases {
|
|
_, _, _, err := r.parseNameExtended(tt.name)
|
|
if !errors.Is(err, tt.err) {
|
|
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
|
}
|
|
if err != nil && !strings.Contains(err.Error(), tt.want) {
|
|
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestParseNameExtended(t *testing.T) {
|
|
cases := []struct {
|
|
in string
|
|
scheme string
|
|
name string
|
|
digest string
|
|
err string
|
|
}{
|
|
{in: "http://m", scheme: "http", name: "m"},
|
|
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
|
|
{in: "http+insecure://m", err: "unsupported scheme"},
|
|
|
|
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
|
|
|
|
{in: "", err: "invalid or missing name"},
|
|
{in: "m", scheme: "https", name: "m"},
|
|
{in: "://", err: "invalid or missing name"},
|
|
{in: "@sha256:deadbeef", err: "invalid digest"},
|
|
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
|
|
}
|
|
for _, tt := range cases {
|
|
t.Run(tt.in, func(t *testing.T) {
|
|
var r Registry
|
|
scheme, n, digest, err := r.parseNameExtended(tt.in)
|
|
if err != nil {
|
|
if tt.err == "" {
|
|
t.Errorf("err = %v; want nil", err)
|
|
} else if !strings.Contains(err.Error(), tt.err) {
|
|
t.Errorf("err = %v; want %q", err, tt.err)
|
|
}
|
|
} else if tt.err != "" {
|
|
t.Errorf("err = nil; want %q", tt.err)
|
|
}
|
|
if err == nil && !n.IsFullyQualified() {
|
|
t.Errorf("name = %q; want fully qualified", n)
|
|
}
|
|
|
|
if scheme != tt.scheme {
|
|
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
|
|
}
|
|
|
|
// smoke-test name is superset of tt.name
|
|
if !strings.Contains(n.String(), tt.name) {
|
|
t.Errorf("name = %q; want %q", n, tt.name)
|
|
}
|
|
|
|
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
|
|
if digest.String() != tt.digest {
|
|
t.Errorf("digest = %q; want %q", digest, tt.digest)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUnlink(t *testing.T) {
|
|
t.Run("found by name", func(t *testing.T) {
|
|
check := testutil.Checker(t)
|
|
|
|
rc, _ := newRegistryClient(t, nil)
|
|
// make a blob and link it
|
|
d := blob.DigestFromBytes("{}")
|
|
err := blob.PutBytes(rc.Cache, d, "{}")
|
|
check(err)
|
|
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
|
check(err)
|
|
|
|
// confirm linked
|
|
_, err = rc.ResolveLocal("single")
|
|
check(err)
|
|
|
|
// unlink
|
|
_, err = rc.Unlink("single")
|
|
check(err)
|
|
|
|
// confirm unlinked
|
|
_, err = rc.ResolveLocal("single")
|
|
if !errors.Is(err, fs.ErrNotExist) {
|
|
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
|
}
|
|
})
|
|
t.Run("not found by name", func(t *testing.T) {
|
|
rc, _ := newRegistryClient(t, nil)
|
|
ok, err := rc.Unlink("manifestNotFound")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ok {
|
|
t.Error("expected not found")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Many tests from here out, in this file are based on a single blob, "abc",
|
|
// with the checksum of its sha256 hash. The checksum is:
|
|
//
|
|
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
|
//
|
|
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
|
// to be the most readable and maintainable approach. The sum is consistently
|
|
// used in the tests and unique so searches do not yield false positives.
|
|
|
|
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
|
t.Helper()
|
|
if got := req.URL.Path; got != path {
|
|
t.Errorf("URL = %q, want %q", got, path)
|
|
}
|
|
if req.Method != method {
|
|
t.Errorf("Method = %q, want %q", req.Method, method)
|
|
}
|
|
}
|
|
|
|
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
|
|
s := httptest.NewServer(upstream)
|
|
t.Cleanup(s.Close)
|
|
cache, err := blob.Open(t.TempDir())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
},
|
|
})
|
|
|
|
rc := &Registry{
|
|
Cache: cache,
|
|
HTTPClient: &http.Client{Transport: &http.Transport{
|
|
Dial: func(network, addr string) (net.Conn, error) {
|
|
return net.Dial(network, s.Listener.Addr().String())
|
|
},
|
|
}},
|
|
}
|
|
return rc, ctx
|
|
}
|
|
|
|
func TestPullChunked(t *testing.T) {
|
|
var steps atomic.Int64
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch steps.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
case 3, 4:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
switch rng := r.Header.Get("Range"); rng {
|
|
case "bytes=0-1":
|
|
io.WriteString(w, "ab")
|
|
case "bytes=2-2":
|
|
t.Logf("writing c")
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected range %q", rng)
|
|
}
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
testutil.Check(t, err)
|
|
|
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
|
testutil.Check(t, err)
|
|
|
|
if g := steps.Load(); g != 4 {
|
|
t.Fatalf("got %d steps, want 4", g)
|
|
}
|
|
}
|
|
|
|
func TestPullCached(t *testing.T) {
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
})
|
|
|
|
check := testutil.Checker(t)
|
|
|
|
// Premeptively cache the blob
|
|
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
check(err)
|
|
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
|
check(err)
|
|
|
|
// Pull only the manifest, which should be enough to resolve the cached blob
|
|
err = c.Pull(ctx, "http://o.com/library/abc")
|
|
check(err)
|
|
}
|
|
|
|
func TestPullManifestError(t *testing.T) {
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
w.WriteHeader(http.StatusNotFound)
|
|
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
|
})
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
if err == nil {
|
|
t.Fatalf("expected error")
|
|
}
|
|
var got *Error
|
|
if !errors.Is(err, ErrModelNotFound) {
|
|
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
|
}
|
|
}
|
|
|
|
func TestPullLayerError(t *testing.T) {
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `!`)
|
|
})
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
if err == nil {
|
|
t.Fatalf("expected error")
|
|
}
|
|
var want *json.SyntaxError
|
|
if !errors.As(err, &want) {
|
|
t.Fatalf("err = %T, want %T", err, want)
|
|
}
|
|
}
|
|
|
|
func TestPullLayerChecksumError(t *testing.T) {
|
|
var step atomic.Int64
|
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
case 3:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
|
case 4:
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.MaxStreams = 1
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
var written atomic.Int64
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
written.Add(n)
|
|
},
|
|
})
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
var got *Error
|
|
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
|
t.Fatalf("err = %v, want %v", err, got)
|
|
}
|
|
|
|
if g := written.Load(); g != 1 {
|
|
t.Fatalf("wrote %d bytes, want 1", g)
|
|
}
|
|
}
|
|
|
|
func TestPullChunksumStreamError(t *testing.T) {
|
|
var step atomic.Int64
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
|
|
// Write one valid chunksum and one invalid chunksum
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
|
fmt.Fprint(w, "sha256:!") // invalid
|
|
case 3:
|
|
io.WriteString(w, "ab")
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
got := c.Pull(ctx, "http://o.com/library/abc")
|
|
if !errors.Is(got, ErrIncomplete) {
|
|
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
|
}
|
|
}
|
|
|
|
type flushAfterWriter struct {
|
|
w io.Writer
|
|
}
|
|
|
|
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
|
n, err = f.w.Write(p)
|
|
f.w.(http.Flusher).Flush() // panic if not a flusher
|
|
return
|
|
}
|
|
|
|
func TestPullChunksumStreaming(t *testing.T) {
|
|
csr, csw := io.Pipe()
|
|
defer csw.Close()
|
|
|
|
var step atomic.Int64
|
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
|
_, err := io.Copy(fw, csr)
|
|
if err != nil {
|
|
t.Errorf("copy: %v", err)
|
|
}
|
|
case 3:
|
|
io.WriteString(w, "ab")
|
|
case 4:
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
update := make(chan int64, 1)
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
if n > 0 {
|
|
update <- n
|
|
}
|
|
},
|
|
})
|
|
|
|
errc := make(chan error, 1)
|
|
go func() {
|
|
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
|
}()
|
|
|
|
// Send first chunksum and ensure it kicks off work immediately
|
|
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
if g := <-update; g != 2 {
|
|
t.Fatalf("got %d, want 2", g)
|
|
}
|
|
|
|
// now send the second chunksum and ensure it kicks off work immediately
|
|
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
if g := <-update; g != 3 {
|
|
t.Fatalf("got %d, want 3", g)
|
|
}
|
|
csw.Close()
|
|
testutil.Check(t, <-errc)
|
|
}
|
|
|
|
func TestPullChunksumsCached(t *testing.T) {
|
|
var step atomic.Int64
|
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
case 3, 4:
|
|
switch rng := r.Header.Get("Range"); rng {
|
|
case "bytes=0-1":
|
|
io.WriteString(w, "ab")
|
|
case "bytes=2-2":
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected range %q", rng)
|
|
}
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.MaxStreams = 1 // force serial processing of chunksums
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
ctx, cancel := context.WithCancel(t.Context())
|
|
defer cancel()
|
|
|
|
// Cancel the pull after the first chunksum is processed, but before
|
|
// the second chunksum is processed (which is waiting because
|
|
// MaxStreams=1). This should cause the second chunksum to error out
|
|
// leaving the blob incomplete.
|
|
ctx = WithTrace(ctx, &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
if n > 0 {
|
|
cancel()
|
|
}
|
|
},
|
|
})
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
|
}
|
|
|
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
|
if !errors.Is(err, fs.ErrNotExist) {
|
|
t.Fatalf("err = %v, want nil", err)
|
|
}
|
|
|
|
// Reset state and pull again to ensure the blob chunks that should
|
|
// have been cached are, and the remaining chunk was downloaded, making
|
|
// the blob complete.
|
|
step.Store(0)
|
|
var written atomic.Int64
|
|
var cached atomic.Int64
|
|
ctx = WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
if errors.Is(err, ErrCached) {
|
|
cached.Add(n)
|
|
}
|
|
written.Add(n)
|
|
},
|
|
})
|
|
|
|
check := testutil.Checker(t)
|
|
|
|
err = c.Pull(ctx, "http://o.com/library/abc")
|
|
check(err)
|
|
|
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
|
check(err)
|
|
|
|
if g := written.Load(); g != 5 {
|
|
t.Fatalf("wrote %d bytes, want 3", g)
|
|
}
|
|
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
|
t.Fatalf("cached %d bytes, want 5", g)
|
|
}
|
|
}
|