package llama4

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type VisionAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the PyTorch implementation using half rotations:
//
//	cos, sin = torch.cos(freqs), torch.sin(freqs)
//	cos = cos.unsqueeze(-1)
//	sin = sin.unsqueeze(-1)
//	t = t.reshape(*t.shape[:-1], -1, 2)
//	t_out = (t * cos) + (_rotate_half(t) * sin)
//	t_out = t_out.flatten(3)
//
// Which is equivalent to the PyTorch implementation using complex numbers:
//
//	t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
//	freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_)  # freqs_ci[:,:,None,:]
//	freqs_ci = freqs_ci.to(t_.device)
//	t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to 1) the dimensional and 2) the datatype limitations of current
// backends, we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)

	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))

	// t1 = t[..., 0::2]
	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)

	// t2 = t[..., 1::2]
	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)

	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)

	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)

	return cosOut.Add(ctx, sinOut)
}

func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))

	query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
	key = applyVisionRotaryEmbedding(ctx, key, cos, sin)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
	return sa.Output.Forward(ctx, attention)
}

type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}

func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
	return hiddenStates
}
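// rotateHalfReference is an illustrative scalar sketch of the rotation that
// applyVisionRotaryEmbedding implements with tensor views; it is a
// hypothetical helper added for documentation and is not called by the model.
// Each pair (t[2i], t[2i+1]) is treated as a complex number and rotated by
// the angle whose cosine and sine are cos[i] and sin[i]:
//
//	out[2i]   = t[2i]*cos[i] - t[2i+1]*sin[i]
//	out[2i+1] = t[2i+1]*cos[i] + t[2i]*sin[i]
//
// This matches cosOut.Add(ctx, sinOut) above: cosOut interleaves
// (t1*cos, t2*cos) and sinOut interleaves (-t2*sin, t1*sin).
func rotateHalfReference(t, cos, sin []float32) []float32 {
	out := make([]float32, len(t))
	for i := 0; i+1 < len(t); i += 2 {
		t1, t2 := t[i], t[i+1] // even/odd elements, i.e. t[..., 0::2] and t[..., 1::2]
		out[i] = t1*cos[i/2] - t2*sin[i/2]
		out[i+1] = t2*cos[i/2] + t1*sin[i/2]
	}
	return out
}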
type VisionLayer struct {
	InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
	*VisionAttention
	PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
	*VisionMLP        `gguf:"mlp"`
}

func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	// MLP
	residual = hiddenStates
	hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	return hiddenStates
}

type VisionAdapter struct {
	FC1 *nn.Linear `gguf:"mlp.fc1"`
	FC2 *nn.Linear `gguf:"mlp.fc2"`
}

func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	patches := hiddenStates.Dim(1)
	patchSize := int(math.Sqrt(float64(patches)))

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))

	channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)

	channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)

	hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)

	return hiddenStates
}

type VisionOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int
	ropeTheta            float32
	eps                  float32
	pixelShuffleRatio    float32
}

type PatchEmbedding struct {
	*nn.Linear
}

func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
	hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
	return p.Linear.Forward(ctx, hiddenStates)
}

type VisionModel struct {
	Layers []VisionLayer `gguf:"blk"`

	*PatchEmbedding `gguf:"patch_embedding"`

	ClassEmbedding      ml.Tensor `gguf:"class_embedding"`
	PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`

	LayerNormPre  *nn.LayerNorm `gguf:"layernorm_pre"`
	LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`

	*VisionAdapter `gguf:"vision_adapter"`

	*VisionOptions
}

func newVisionModel(c fs.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionLayer, c.Uint("vision.block_count")),
		VisionOptions: &VisionOptions{
			hiddenSize:        int(c.Uint("vision.embedding_length")),
			numHeads:          int(c.Uint("vision.attention.head_count")),
			imageSize:         int(c.Uint("vision.image_size")),
			patchSize:         int(c.Uint("vision.patch_size")),
			ropeTheta:         float32(c.Float("vision.rope.freq_base")),
			eps:               c.Float("vision.layer_norm_epsilon"),
			pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
		},
	}
}
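// pixelShuffleDims is an illustrative sketch of how the two reshape/permute
// passes in VisionAdapter.Forward above change the activation shape; it is a
// hypothetical helper added for documentation and is not called by the model.
// For example, with pixelShuffleRatio = 0.5 a (C, W, H) activation becomes
// (4C, W/2, H/2): each 2x2 block of patches is folded into the channel axis
// before the adapter MLP, quartering the sequence length.
func pixelShuffleDims(channels, width, height int, ratio float32) (int, int, int) {
	channels, width = int(float32(channels)/ratio), int(float32(width)*ratio)
	channels, height = int(float32(channels)/ratio), int(float32(height)*ratio)
	return channels, width, height
}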
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
	hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)

	hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)

	hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)

	cos, sin := m.rotaryEmbedding(ctx)
	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
	}

	hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
	hiddenStates = hiddenStates.Pad(ctx, 0, -1, 0, 0)
	hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
	return hiddenStates
}

// floorDiv is a helper function to perform floor division. This mimics PyTorch's
// div(rounding_mode='floor') function, which in turn mimics Python's // operator.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
	if b == 0 {
		panic("division by zero")
	}

	if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
		return a / b
	}

	return a/b - 1
}

func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
	patchesPerSide := m.imageSize / m.patchSize
	numPatches := patchesPerSide*patchesPerSide + 1

	headDim := m.hiddenSize / m.numHeads
	freqDim := headDim / 2

	freqs := make([]float32, numPatches*freqDim)
	for i := range numPatches - 1 {
		for j := 0; j < freqDim; j += 2 {
			positionX := i*freqDim/2 + j/2
			positionY := (i+numPatches)*freqDim/2 + j/2

			ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
			freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
			freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
		}
	}

	ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
	if err != nil {
		panic(err)
	}

	ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)

	return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}
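// The two helpers below are illustrative sketches added for documentation;
// they are hypothetical and not called by the model.

// floorDivExample contrasts floorDiv with Go's built-in /, which truncates
// toward zero. Python's // operator (which floorDiv mimics) rounds toward
// negative infinity instead, so the two differ exactly when the operands have
// opposite signs and the division is inexact.
func floorDivExample() [3]int {
	return [3]int{
		floorDiv(7, 2),  // 3, same as 7/2
		floorDiv(-7, 2), // -4, whereas -7/2 == -3
		floorDiv(-8, 2), // -4, same as -8/2 (exact division)
	}
}

// ropePosition shows the 1-based 2D grid coordinate that rotaryEmbedding
// assigns to patch index i before dividing by the theta^(2j/headDim)
// frequency term: for non-negative i the expressions reduce to column
// 1 + i%patchesPerSide and row 1 + i/patchesPerSide.
func ropePosition(i, patchesPerSide int) (x, y int) {
	x = 1 + i - floorDiv(i, patchesPerSide)*patchesPerSide
	y = 1 + floorDiv(i, patchesPerSide)
	return x, y
}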