You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1198 lines
31 KiB

3 weeks ago
// Package ollama provides a client for interacting with an Ollama registry
// which pushes and pulls model manifests and layers as defined by the
// [ollama.com/manifest].
package ollama
import (
"bufio"
"bytes"
"cmp"
"context"
"crypto"
"crypto/ed25519"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/internal/names"
_ "embed"
)
// Errors
var (
// ErrModelNotFound is returned when a manifest is not found in the
// cache or registry.
ErrModelNotFound = errors.New("model not found")
// ErrManifestInvalid is returned when a manifest found in a local or
// remote cache is invalid.
ErrManifestInvalid = errors.New("invalid manifest")
// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
ErrNameInvalid = errors.New("invalid or missing name")
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached")
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
// incomplete due to one or more layer download failures. Users that
// want specific errors should use [WithTrace].
ErrIncomplete = errors.New("incomplete")
)
// Defaults
const (
// DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading.
DefaultChunkingThreshold = 64 << 20
)
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
home, _ := os.UserHomeDir()
home = cmp.Or(home, ".")
dir = filepath.Join(home, ".ollama", "models")
}
return blob.Open(dir)
})
// DefaultCache returns the default cache used by the registry. It is
// configured from the OLLAMA_MODELS environment variable, or defaults to
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
// it uses the current working directory.
func DefaultCache() (*blob.DiskCache, error) {
return defaultCache()
}
// Error is the standard error returned by Ollama APIs. It can represent a
// single or multiple error response.
//
// Single error responses have the following format:
//
// {"code": "optional_code","error":"error message"}
//
// Multiple error responses have the following format:
//
// {"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note, that the error field is used in single error responses, while the
// message field is used in multiple error responses.
//
// In both cases, the code field is optional and may be empty.
type Error struct {
status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
// Temporary reports if the error is temporary (e.g. 5xx status code).
func (e *Error) Temporary() bool {
return e.status >= 500
}
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("registry responded with status ")
b.WriteString(strconv.Itoa(e.status))
if e.Code != "" {
b.WriteString(": code ")
b.WriteString(e.Code)
}
if e.Message != "" {
b.WriteString(": ")
b.WriteString(e.Message)
}
return b.String()
}
func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *Error) UnmarshalJSON(b []byte) error {
type E Error
var v struct {
// Single error
Code string
Error string
// Multiple errors
Errors []E
}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
if v.Error != "" {
// Single error case
e.Code = v.Code
e.Message = v.Error
return nil
}
if len(v.Errors) == 0 {
return fmt.Errorf("no messages in error response: %s", string(b))
}
*e = Error(v.Errors[0]) // our registry only returns one error.
return nil
}
const DefaultMask = "registry.ollama.ai/library/_:latest"
var defaultMask = func() names.Name {
n := names.Parse(DefaultMask)
if !n.IsFullyQualified() {
panic("default mask is not fully qualified")
}
return n
}()
// CompleteName returns a fully qualified name by merging the given name with
// the default mask. If the name is already fully qualified, it is returned
// unchanged.
func CompleteName(name string) string {
return names.Merge(names.Parse(name), defaultMask).String()
}
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
// Cache is the cache used to store models. If nil, [DefaultCache] is
// used.
Cache *blob.DiskCache
// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
// Key is the key used to authenticate with the registry.
//
// Currently, only Ed25519 keys are supported.
Key crypto.PrivateKey
// HTTPClient is the HTTP client used to make requests to the registry.
//
// If nil, [http.DefaultClient] is used.
//
// As a quick note: If a Registry function that makes a call to a URL
// with the "https+insecure" scheme, the client will be cloned and the
// transport will be set to skip TLS verification, unless the client's
// Transport done not have a Clone method with the same signature as
// [http.Transport.Clone], which case, the call will fail.
HTTPClient *http.Client
// MaxStreams is the maximum number of concurrent streams to use when
// pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS].
//
// A negative value means no limit.
MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single
// request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names.
// If empty, [DefaultMask] is used.
Mask string
// ReadTimeout is the maximum duration for reading the entire request,
// including the body.
// A zero or negative value means there will be no timeout.
ReadTimeout time.Duration
}
func (r *Registry) readTimeout() time.Duration {
if r.ReadTimeout > 0 {
return r.ReadTimeout
}
return 1<<63 - 1 // no timeout, max int
}
func (r *Registry) cache() (*blob.DiskCache, error) {
if r.Cache != nil {
return r.Cache, nil
}
return defaultCache()
}
func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask
if r.Mask != "" {
mask = names.Parse(r.Mask)
}
n := names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
}
// DefaultRegistry returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ReadTimeout is set to 30 seconds.
//
// It returns an error if any configuration in the environment is invalid.
func DefaultRegistry() (*Registry, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
if err != nil && errors.Is(err, fs.ErrNotExist) {
return nil, err
}
var rc Registry
rc.ReadTimeout = 30 * time.Second
rc.UserAgent = UserAgent()
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil {
return nil, err
}
maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS")
if maxStreams != "" {
var err error
rc.MaxStreams, err = strconv.Atoi(maxStreams)
if err != nil {
return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err)
}
}
return &rc, nil
}
func UserAgent() string {
buildinfo, _ := debug.ReadBuildInfo()
version := buildinfo.Main.Version
if version == "(devel)" {
// When using `go run .` the version is "(devel)". This is seen
// as an invalid version by ollama.com and so it defaults to
// "needs upgrade" for some requests, such as pulls. These
// checks can be skipped by using the special version "v0.0.0",
// so we set it to that here.
version = "v0.0.0"
}
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
version,
runtime.GOARCH,
runtime.GOOS,
runtime.Version(),
)
}
func (r *Registry) maxStreams() int {
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
}
func (r *Registry) maxChunkingThreshold() int64 {
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}
// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}
c, err := r.cache()
if err != nil {
return err
}
m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil {
return err
}
// Before much else happens, check layers at not null, the blobs exist,
// and the sizes match. This prevents long uploads followed by
// disappointment.
for _, l := range m.Layers {
if l == nil {
return fmt.Errorf("%w: null layer", ErrManifestInvalid)
}
info, err := c.Get(l.Digest)
if err != nil {
return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err)
}
if info.Size != l.Size {
return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size)
}
}
t := traceFromContext(ctx)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
panic(err)
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range m.Layers {
var progress atomic.Int64
g.Go(func() (err error) {
defer func() { t.update(l, progress.Load(), err) }()
t.update(l, 0, nil)
startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
res, err := r.send(ctx, "POST", startURL, nil)
if err != nil {
return err
}
res.Body.Close()
f, err := os.Open(c.GetFile(l.Digest))
if err != nil {
return err
}
defer f.Close()
uploadURL := res.Header.Get("Location")
if uploadURL == "" {
t.update(l, l.Size, ErrCached)
return nil
}
req, err := r.newRequest(ctx, "PUT", uploadURL, f)
if err != nil {
return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err)
}
req.ContentLength = l.Size
res, err = sendRequest(r.client(), req)
if err == nil {
res.Body.Close()
}
return err
})
}
if err := g.Wait(); err != nil {
return err
}
// Commit
path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
n.Tag(),
)
res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
if err == nil {
res.Body.Close()
}
// TODO(bmizerany): add a "commit" trace event
return err
}
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
// It always calls update with a nil error.
type trackingReader struct {
r io.Reader
update func(n int64, err error) // err is always nil
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.update(int64(n), nil)
return
}
// Pull pulls the model with the given name from the remote registry into the
// cache.
//
// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, name string) error {
m, err := r.Resolve(ctx, name)
if err != nil {
return err
}
// TODO(bmizerany): decide if this should be considered valid. Maybe
// server-side we special case '{}' to have some special meaning? Maybe
// "archiving" a tag (which is how we reason about it in the registry
// already, just with a different twist).
if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
c, err := r.cache()
if err != nil {
return err
}
// TODO(bmizerany): work to remove the need to do this
layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config)
}
// Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts.
var expected int64
t := traceFromContext(ctx)
for _, l := range layers {
t.update(l, 0, nil)
expected += l.Size
}
var g errgroup.Group
g.SetLimit(r.maxStreams())
var completed atomic.Int64
for _, l := range layers {
var received atomic.Int64
update := func(n int64, err error) {
if n == 0 && err == nil {
// Clients expect an update with no progress and no error to mean "starting download".
// This is not the case here,
// so we don't want to send an update in this case.
return
}
completed.Add(n)
t.update(l, received.Add(n), err)
}
info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
update(l.Size, ErrCached)
continue
}
func() (err error) {
defer func() {
if err != nil {
update(0, err)
}
}()
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
return err
}
defer func() {
// Close the chunked writer when all chunks are
// downloaded.
//
// This is done as a background task in the
// group to allow the next layer to start while
// we wait for the final chunk in this layer to
// complete. It also ensures this is done
// before we exit Pull.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}()
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Note the chunksum stream
// interuption, but do not cancel
// in-flight downloads. We can still
// make progress on them. Once they are
// done, ErrIncomplete will be returned
// below.
update(0, err)
break
}
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
update(cs.Chunk.Size(), ErrCached)
continue
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
defer wg.Done()
if err != nil {
update(0, err)
}
}()
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
timer := time.AfterFunc(r.readTimeout(), func() {
cancel(fmt.Errorf("%w: downloading %s %d-%d/%d",
context.DeadlineExceeded,
cs.Digest.Short(),
cs.Chunk.Start,
cs.Chunk.End,
l.Size,
))
})
defer timer.Stop()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
tr := &trackingReader{
r: res.Body,
update: func(n int64, err error) {
timer.Reset(r.readTimeout())
update(n, err)
},
}
if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil {
return err
}
// Record the downloading of this chunk.
return blob.PutBytes(c, cacheKeyDigest, cacheKey)
})
}
return nil
}()
}
if err := g.Wait(); err != nil {
return err
}
if recv := completed.Load(); recv != expected {
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, recv, expected)
}
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
return err
}
return c.Link(m.Name, md)
}
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
return false, err
}
c, err := r.cache()
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
// Manifest represents a [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
Data []byte `json:"-"` // the raw data of the manifest
Layers []*Layer `json:"layers"`
// For legacy reasons, we still have to download the config layer.
Config *Layer `json:"config"`
}
// Layer returns the layer with the given
// digest, or nil if not found.
func (m *Manifest) Layer(d blob.Digest) *Layer {
for _, l := range m.Layers {
if l.Digest == d {
return l
}
}
return nil
}
func (m *Manifest) All() iter.Seq[*Layer] {
return func(yield func(*Layer) bool) {
if !yield(m.Config) {
return
}
for _, l := range m.Layers {
if !yield(l) {
return
}
}
}
}
func (m *Manifest) Size() int64 {
var size int64
if m.Config != nil {
size += m.Config.Size
}
for _, l := range m.Layers {
size += l.Size
}
return size
}
// MarshalJSON implements json.Marshaler.
//
// NOTE: It adds an empty config object to the manifest, which is required by
// the registry, but not used by the client. In the future, the config object
// will not be required by the registry and this will should be removed.
func (m Manifest) MarshalJSON() ([]byte, error) {
type M Manifest
v := struct {
M
// This is ignored, mostly, by the registry But, if not
// present, it will cause an error to be returned during the
// last phase of the commit which expects it, but does nothing
// with it. This will be fixed in a future release of
// ollama.com.
Config Layer `json:"config"`
}{
M: M(m),
}
return json.Marshal(v)
}
// unmarshalManifest unmarshals the data into a manifest, and sets the name
// field to the string representation of the name.
//
// It panics if the name is not fully qualified. Callers should ensure the name
// is fully qualified before calling this function.
func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
if !n.IsFullyQualified() {
panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String()))
}
var m Manifest
if err := json.Unmarshal(data, &m); err != nil {
return nil, err
}
m.Name = n.String()
m.Data = data
return &m, nil
}
// Layer is a layer in a model.
type Layer struct {
Digest blob.Digest `json:"digest"`
MediaType string `json:"mediaType"`
Size int64 `json:"size"`
}
// ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
c, err := r.cache()
if err != nil {
return nil, err
}
if !d.IsValid() {
// No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
}
}
data, err := os.ReadFile(c.GetFile(d))
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
}
return nil, err
}
m, err := unmarshalManifest(n, data)
if err != nil {
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
}
return m, nil
}
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag())
if d.IsValid() {
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
}
res, err := r.send(ctx, "GET", manifestURL, nil)
if err != nil {
return nil, err
}
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
// TODO(bmizerany): return digest here
m, err := unmarshalManifest(n, data)
if err != nil {
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
}
return m, nil
}
type chunksum struct {
URL string
Chunk blob.Chunk
Digest blob.Digest
}
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
return func(yield func(chunksum, error) bool) {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
yield(chunksum{}, err)
return
}
if l.Size < r.maxChunkingThreshold() {
// any layer under the threshold should be downloaded
// in one go.
cs := chunksum{
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
),
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
Digest: l.Digest,
}
yield(cs, nil)
return
}
// The response is a sequence of chunksums.
//
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,
n.Host(),
n.Namespace(),
n.Model(),
l.Digest,
)
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
if err != nil {
yield(chunksum{}, err)
return
}
res, err := sendRequest(r.client(), req)
if err != nil {
yield(chunksum{}, err)
return
}
defer res.Body.Close()
if res.StatusCode != 200 {
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
yield(chunksum{}, err)
return
}
blobURL := res.Header.Get("Content-Location")
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
}
return
}
d, err := blob.ParseDigest(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
return
}
if !s.Scan() {
err := s.Err()
if err == nil {
err = fmt.Errorf("missing chunk range for digest %s", d)
}
yield(chunksum{}, err)
return
}
chunk, err := parseChunk(s.Bytes())
if err != nil {
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
Digest: d,
}
if !yield(cs, nil) {
return
}
}
}
}
func (r *Registry) client() *http.Client {
if r.HTTPClient != nil {
return r.HTTPClient
}
return http.DefaultClient
}
// newRequest constructs a new request, ready to use, with the given method,
// url, and body, pre-signed with client [Key] and [UserAgent].
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
if r.UserAgent != "" {
req.Header.Set("User-Agent", r.UserAgent)
}
if r.Key != nil {
token, err := makeAuthToken(r.Key)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token)
}
return req, nil
}
// sendRequest makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it
// is returned.
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
if r.URL.Scheme == "https+insecure" {
// TODO(bmizerany): clone client.Transport, set
// InsecureSkipVerify, etc.
type cloner interface {
Clone() *http.Transport
}
// Attempt to configure the transport to skip TLS verification
// if we can clone it, otherwise fall through and let the http
// client complain and the scheme being invalid.
x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner)
if ok {
tr := x.Clone()
tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{})
tr.TLSClientConfig.InsecureSkipVerify = true
cc := *c // shallow copy
cc.Transport = tr
c = &cc
r = r.Clone(r.Context())
r.URL.Scheme = "https"
// fall through
}
}
res, err := c.Do(r)
if err != nil {
return nil, err
}
if res.StatusCode/100 != 2 {
out, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var re Error
if err := json.Unmarshal(out, &re); err != nil {
// Use the raw body if we can't parse it as an error object.
re.Message = string(out)
}
// coerce MANIFEST_UNKNOWN to ErrManifestNotFound
if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
return nil, ErrModelNotFound
}
re.status = res.StatusCode
return nil, &re
}
return res, nil
}
// send is a convenience method for making a request with newRequest and
// passing it to send with r.client().
func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := r.newRequest(ctx, method, path, body)
if err != nil {
return nil, err
}
return sendRequest(r.client(), req)
}
// makeAuthToken creates an Ollama auth token for the given private key.
//
// NOTE: This format is OLD, overly complex, and should be replaced. We're
// inheriting it from the original Ollama client and ollama.com
// implementations, so we need to support it for now.
func makeAuthToken(key crypto.PrivateKey) (string, error) {
privKey, _ := key.(*ed25519.PrivateKey)
if privKey == nil {
return "", fmt.Errorf("unsupported private key type: %T", key)
}
url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix())
// Part 1: the checkData (e.g. the URL with a timestamp)
// Part 2: the public key
pubKeyShort, err := func() ([]byte, error) {
sshPubKey, err := ssh.NewPublicKey(privKey.Public())
if err != nil {
return nil, err
}
pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey))
if len(pubKeyParts) < 2 {
return nil, fmt.Errorf("malformed public key: %q", pubKeyParts)
}
pubKeyShort := pubKeyParts[1]
return pubKeyShort, nil
}()
if err != nil {
return "", err
}
// Part 3: the signature
sig := ed25519.Sign(*privKey, []byte(checkData(url)))
// Assemble the token: <checkData>:<pubKey>:<signature>
var b strings.Builder
io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url)))
b.WriteByte(':')
b.Write(pubKeyShort)
b.WriteByte(':')
io.WriteString(&b, base64.StdEncoding.EncodeToString(sig))
return b.String(), nil
}
// The original spec for Ollama tokens was to use the SHA256 of the zero
// string as part of the signature. I'm not sure why that was, but we still
// need it to verify the signature.
var zeroSum = func() string {
sha256sum := sha256.Sum256(nil)
x := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))
return x
}()
// checkData takes a URL and creates the original string format of the
// data signature that is used by the ollama client to sign requests
func checkData(url string) string {
return fmt.Sprintf("GET,%s,%s", url, zeroSum)
}
type publicError struct {
wrapped error
message string
}
func withPublicMessagef(err error, message string, args ...any) error {
return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
}
func (e publicError) Error() string { return e.message }
func (e publicError) Unwrap() error { return e.wrapped }
var supportedSchemes = []string{
"http",
"https",
"https+insecure",
}
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
//
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
// message is returned.
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
return "", names.Name{}, blob.Digest{}, err
}
var d blob.Digest
if digest != "" {
var err error
d, err = blob.ParseDigest(digest)
if err != nil {
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
return "", names.Name{}, blob.Digest{}, err
}
if name == "" {
// We have can resolve a manifest from a digest only,
// so skip name validation and return the scheme and
// digest.
return scheme, names.Name{}, d, nil
}
}
n, err := r.parseName(name)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
return scheme, n, d, nil
}
// splitExtended splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
// model@digest
// @digest
func splitExtended(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, s, digest
}
// parseChunk parses a string in the form "start-end" and returns the Chunk.
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
startPart, endPart, found := strings.Cut(string(s), "-")
if !found {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
}
start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
}
end, err := strconv.ParseInt(endPart, 10, 64)
if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
}
if start > end {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
}
return blob.Chunk{Start: start, End: end}, nil
}