You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
418 lines
11 KiB
418 lines
11 KiB
3 weeks ago
|
// Package registry implements an http.Handler for handling local Ollama API
|
||
|
// model management requests. See [Local] for details.
|
||
|
package registry
|
||
|
|
||
|
import (
|
||
|
"cmp"
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log/slog"
|
||
|
"net/http"
|
||
|
"slices"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||
|
"github.com/ollama/ollama/server/internal/internal/backoff"
|
||
|
)
|
||
|
|
||
|
// Local implements an http.Handler for handling local Ollama API model
// management requests, such as pushing, pulling, and deleting models.
//
// It can be arranged for all unknown requests to be passed through to a
// fallback handler, if one is provided.
//
// The paths handled directly are /api/pull (POST) and /api/delete (DELETE);
// any other method on those paths is answered with 405. Errors returned by
// these handlers are written to the client as a JSON error object.
type Local struct {
	// Client is the registry client used to pull models (Pull) and to
	// remove them (Unlink).
	Client *ollama.Registry // required

	// Logger receives one record per request handled by Local itself (not
	// by Fallback), plus errors reported by the registry while pulling.
	Logger *slog.Logger // required

	// Fallback, if set, is used to handle requests that are not handled by
	// this handler. Requests passed to Fallback are not logged by Local.
	Fallback http.Handler

	// Prune, if set, is called to prune the local disk cache after a model
	// is deleted. An error from Prune becomes the delete request's error.
	Prune func() error // optional
}
|
||
|
|
||
|
// serverError is an API error that also carries the HTTP status code to
// respond with.
//
// It mirrors ollama.Error plus a Status field. Status is deliberately not
// added to ollama.Error: it is never serialized (note the "-" tag), so on the
// client side it would always read as 0, and we don't want to leak status
// codes in errors anyway. A field that is always 0 would only confuse.
type serverError struct {
	// Status is the HTTP response code; it is omitted from the JSON body.
	Status int `json:"-"`

	// Code is a short machine-readable identifier such as "not_found".
	//
	// TODO(bmizerany): Decide if we want to keep this and maybe
	// bring back later.
	Code string `json:"code"`

	// Message is the human-readable text, encoded as the "error" field.
	Message string `json:"error"`
}

// Error implements the error interface; it reports only the Message.
func (e serverError) Error() string { return e.Message }
|
||
|
|
||
|
// Common API errors
|
||
|
var (
|
||
|
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
|
||
|
errNotFound = &serverError{404, "not_found", "not found"}
|
||
|
errModelNotFound = &serverError{404, "not_found", "model not found"}
|
||
|
errInternalError = &serverError{500, "internal_error", "internal server error"}
|
||
|
)
|
||
|
|
||
|
type statusCodeRecorder struct {
|
||
|
_status int // use status() to get the status code
|
||
|
http.ResponseWriter
|
||
|
}
|
||
|
|
||
|
func (r *statusCodeRecorder) WriteHeader(status int) {
|
||
|
if r._status == 0 {
|
||
|
r._status = status
|
||
|
r.ResponseWriter.WriteHeader(status)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *statusCodeRecorder) Write(b []byte) (int, error) {
|
||
|
r._status = r.status()
|
||
|
return r.ResponseWriter.Write(b)
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
_ http.ResponseWriter = (*statusCodeRecorder)(nil)
|
||
|
_ http.CloseNotifier = (*statusCodeRecorder)(nil)
|
||
|
_ http.Flusher = (*statusCodeRecorder)(nil)
|
||
|
)
|
||
|
|
||
|
// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
|
||
|
//
|
||
|
// It panics if the underlying ResponseWriter is not a CloseNotifier.
|
||
|
func (r *statusCodeRecorder) CloseNotify() <-chan bool {
|
||
|
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
|
||
|
}
|
||
|
|
||
|
// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
|
||
|
//
|
||
|
// It panics if the underlying ResponseWriter is not a Flusher.
|
||
|
func (r *statusCodeRecorder) Flush() {
|
||
|
r.ResponseWriter.(http.Flusher).Flush()
|
||
|
}
|
||
|
|
||
|
func (r *statusCodeRecorder) status() int {
|
||
|
return cmp.Or(r._status, 200)
|
||
|
}
|
||
|
|
||
|
func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||
|
rec := &statusCodeRecorder{ResponseWriter: w}
|
||
|
s.serveHTTP(rec, r)
|
||
|
}
|
||
|
|
||
|
func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
||
|
var errattr slog.Attr
|
||
|
proxied, err := func() (bool, error) {
|
||
|
switch r.URL.Path {
|
||
|
case "/api/delete":
|
||
|
return false, s.handleDelete(rec, r)
|
||
|
case "/api/pull":
|
||
|
return false, s.handlePull(rec, r)
|
||
|
default:
|
||
|
if s.Fallback != nil {
|
||
|
s.Fallback.ServeHTTP(rec, r)
|
||
|
return true, nil
|
||
|
}
|
||
|
return false, errNotFound
|
||
|
}
|
||
|
}()
|
||
|
if err != nil {
|
||
|
// We always log the error, so fill in the error log attribute
|
||
|
errattr = slog.String("error", err.Error())
|
||
|
|
||
|
var e *serverError
|
||
|
switch {
|
||
|
case errors.As(err, &e):
|
||
|
case errors.Is(err, ollama.ErrNameInvalid):
|
||
|
e = &serverError{400, "bad_request", err.Error()}
|
||
|
default:
|
||
|
e = errInternalError
|
||
|
}
|
||
|
|
||
|
data, err := json.Marshal(e)
|
||
|
if err != nil {
|
||
|
// unreachable
|
||
|
panic(err)
|
||
|
}
|
||
|
rec.Header().Set("Content-Type", "application/json")
|
||
|
rec.WriteHeader(e.Status)
|
||
|
rec.Write(data)
|
||
|
|
||
|
// fallthrough to log
|
||
|
}
|
||
|
|
||
|
if !proxied {
|
||
|
// we're only responsible for logging if we handled the request
|
||
|
var level slog.Level
|
||
|
if rec.status() >= 500 {
|
||
|
level = slog.LevelError
|
||
|
} else if rec.status() >= 400 {
|
||
|
level = slog.LevelWarn
|
||
|
}
|
||
|
|
||
|
s.Logger.LogAttrs(r.Context(), level, "http",
|
||
|
errattr, // report first in line to make it easy to find
|
||
|
|
||
|
// TODO(bmizerany): Write a test to ensure that we are logging
|
||
|
// all of this correctly. That also goes for the level+error
|
||
|
// logic above.
|
||
|
slog.Int("status", rec.status()),
|
||
|
slog.String("method", r.Method),
|
||
|
slog.String("path", r.URL.Path),
|
||
|
slog.Int64("content-length", r.ContentLength),
|
||
|
slog.String("remote", r.RemoteAddr),
|
||
|
slog.String("proto", r.Proto),
|
||
|
slog.String("query", r.URL.RawQuery),
|
||
|
)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type params struct {
|
||
|
// DeprecatedName is the name of the model to push, pull, or delete,
|
||
|
// but is deprecated. New clients should use [Model] instead.
|
||
|
//
|
||
|
// Use [model()] to get the model name for both old and new API requests.
|
||
|
DeprecatedName string `json:"name"`
|
||
|
|
||
|
// Model is the name of the model to push, pull, or delete.
|
||
|
//
|
||
|
// Use [model()] to get the model name for both old and new API requests.
|
||
|
Model string `json:"model"`
|
||
|
|
||
|
// AllowNonTLS is a flag that indicates a client using HTTP
|
||
|
// is doing so, deliberately.
|
||
|
//
|
||
|
// Deprecated: This field is ignored and only present for this
|
||
|
// deprecation message. It should be removed in a future release.
|
||
|
//
|
||
|
// Users can just use http or https+insecure to show intent to
|
||
|
// communicate they want to do insecure things, without awkward and
|
||
|
// confusing flags such as this.
|
||
|
AllowNonTLS bool `json:"insecure"`
|
||
|
|
||
|
// Stream, if true, will make the server send progress updates in a
|
||
|
// streaming of JSON objects. If false, the server will send a single
|
||
|
// JSON object with the final status as "success", or an error object
|
||
|
// if an error occurred.
|
||
|
//
|
||
|
// Unfortunately, this API was designed to be a bit awkward. Stream is
|
||
|
// defined to default to true if not present, so we need a way to check
|
||
|
// if the client decisively set it to false. So, we use a pointer to a
|
||
|
// bool. Gross.
|
||
|
//
|
||
|
// Use [stream()] to get the correct value for this field.
|
||
|
Stream *bool `json:"stream"`
|
||
|
}
|
||
|
|
||
|
// model returns the model name for both old and new API requests.
|
||
|
func (p params) model() string {
|
||
|
return cmp.Or(p.Model, p.DeprecatedName)
|
||
|
}
|
||
|
|
||
|
func (p params) stream() bool {
|
||
|
if p.Stream == nil {
|
||
|
return true
|
||
|
}
|
||
|
return *p.Stream
|
||
|
}
|
||
|
|
||
|
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||
|
if r.Method != "DELETE" {
|
||
|
return errMethodNotAllowed
|
||
|
}
|
||
|
p, err := decodeUserJSON[*params](r.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
ok, err := s.Client.Unlink(p.model())
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if !ok {
|
||
|
return errModelNotFound
|
||
|
}
|
||
|
if s.Prune != nil {
|
||
|
return s.Prune()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// progressUpdateJSON is one JSON object written to the client by handlePull
// via json.Encoder (one object per line). It carries one of two things:
//   - a status message (Status);
//   - per-layer progress (Digest, Total, Completed).
//
// Zero-valued fields are omitted from the encoding.
type progressUpdateJSON struct {
	// NOTE(review): Error is never set by handlePull; presumably kept for
	// compatibility with the older progress format — confirm.
	Error  string `json:"error,omitempty,omitzero"`
	Status string `json:"status,omitempty,omitzero"`

	// Digest identifies the layer this progress entry describes.
	Digest blob.Digest `json:"digest,omitempty,omitzero"`

	// Total is the layer size; Completed is the latest progress value
	// reported for the layer by the registry client.
	Total     int64 `json:"total,omitempty,omitzero"`
	Completed int64 `json:"completed,omitempty,omitzero"`
}
|
||
|
|
||
|
func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||
|
if r.Method != "POST" {
|
||
|
return errMethodNotAllowed
|
||
|
}
|
||
|
|
||
|
p, err := decodeUserJSON[*params](r.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
enc := json.NewEncoder(w)
|
||
|
if !p.stream() {
|
||
|
if err := s.Client.Pull(r.Context(), p.model()); err != nil {
|
||
|
if errors.Is(err, ollama.ErrModelNotFound) {
|
||
|
return errModelNotFound
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
enc.Encode(progressUpdateJSON{Status: "success"})
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
var mu sync.Mutex
|
||
|
var progress []progressUpdateJSON
|
||
|
flushProgress := func() {
|
||
|
mu.Lock()
|
||
|
progress := slices.Clone(progress) // make a copy and release lock before encoding to the wire
|
||
|
mu.Unlock()
|
||
|
for _, p := range progress {
|
||
|
enc.Encode(p)
|
||
|
}
|
||
|
fl, _ := w.(http.Flusher)
|
||
|
if fl != nil {
|
||
|
fl.Flush()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
t := time.NewTicker(1<<63 - 1) // "unstarted" timer
|
||
|
start := sync.OnceFunc(func() {
|
||
|
flushProgress() // flush initial state
|
||
|
t.Reset(100 * time.Millisecond)
|
||
|
})
|
||
|
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||
|
Update: func(l *ollama.Layer, n int64, err error) {
|
||
|
if err != nil && !errors.Is(err, ollama.ErrCached) {
|
||
|
s.Logger.Error("pulling", "model", p.model(), "error", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func() {
|
||
|
mu.Lock()
|
||
|
defer mu.Unlock()
|
||
|
for i, p := range progress {
|
||
|
if p.Digest == l.Digest {
|
||
|
progress[i].Completed = n
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
progress = append(progress, progressUpdateJSON{
|
||
|
Digest: l.Digest,
|
||
|
Total: l.Size,
|
||
|
})
|
||
|
}()
|
||
|
|
||
|
// Block flushing progress updates until every
|
||
|
// layer is accounted for. Clients depend on a
|
||
|
// complete model size to calculate progress
|
||
|
// correctly; if they use an incomplete total,
|
||
|
// progress indicators would erratically jump
|
||
|
// as new layers are registered.
|
||
|
start()
|
||
|
},
|
||
|
})
|
||
|
|
||
|
done := make(chan error, 1)
|
||
|
go func() (err error) {
|
||
|
defer func() { done <- err }()
|
||
|
for _, err := range backoff.Loop(ctx, 3*time.Second) {
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
err := s.Client.Pull(ctx, p.model())
|
||
|
if canRetry(err) {
|
||
|
continue
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}()
|
||
|
|
||
|
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||
|
for {
|
||
|
select {
|
||
|
case <-t.C:
|
||
|
flushProgress()
|
||
|
case err := <-done:
|
||
|
flushProgress()
|
||
|
if err != nil {
|
||
|
if errors.Is(err, ollama.ErrModelNotFound) {
|
||
|
return &serverError{
|
||
|
Status: 404,
|
||
|
Code: "not_found",
|
||
|
Message: fmt.Sprintf("model %q not found", p.model()),
|
||
|
}
|
||
|
} else {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Emulate old client pull progress (for now):
|
||
|
enc.Encode(progressUpdateJSON{Status: "verifying sha256 digest"})
|
||
|
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
|
||
|
enc.Encode(progressUpdateJSON{Status: "success"})
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||
|
var v T
|
||
|
err := json.NewDecoder(r).Decode(&v)
|
||
|
if err == nil {
|
||
|
return v, nil
|
||
|
}
|
||
|
var zero T
|
||
|
|
||
|
// Not sure why, but I can't seem to be able to use:
|
||
|
//
|
||
|
// errors.As(err, &json.UnmarshalTypeError{})
|
||
|
//
|
||
|
// This is working fine in stdlib, so I'm not sure what rules changed
|
||
|
// and why this no longer works here. So, we do it the verbose way.
|
||
|
var a *json.UnmarshalTypeError
|
||
|
var b *json.SyntaxError
|
||
|
if errors.As(err, &a) || errors.As(err, &b) {
|
||
|
err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
|
||
|
}
|
||
|
if errors.Is(err, io.EOF) {
|
||
|
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
|
||
|
}
|
||
|
return zero, err
|
||
|
}
|
||
|
|
||
|
func canRetry(err error) bool {
|
||
|
if err == nil {
|
||
|
return false
|
||
|
}
|
||
|
var oe *ollama.Error
|
||
|
if errors.As(err, &oe) {
|
||
|
return oe.Temporary()
|
||
|
}
|
||
|
s := err.Error()
|
||
|
return cmp.Or(
|
||
|
errors.Is(err, context.DeadlineExceeded),
|
||
|
strings.Contains(s, "unreachable"),
|
||
|
strings.Contains(s, "no route to host"),
|
||
|
strings.Contains(s, "connection reset by peer"),
|
||
|
)
|
||
|
}
|