You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
60 lines
1.0 KiB
60 lines
1.0 KiB
3 weeks ago
|
package convert
|
||
|
|
||
|
import (
	"fmt"
	"io"
	"io/fs"
	"strings"

	"github.com/nlpodyssey/gopickle/pytorch"
	"github.com/nlpodyssey/gopickle/types"
)
|
||
|
|
||
|
func parseTorch(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]Tensor, error) {
|
||
|
var ts []Tensor
|
||
|
for _, p := range ps {
|
||
|
pt, err := pytorch.Load(p)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
for _, k := range pt.(*types.Dict).Keys() {
|
||
|
t := pt.(*types.Dict).MustGet(k)
|
||
|
|
||
|
var shape []uint64
|
||
|
for dim := range t.(*pytorch.Tensor).Size {
|
||
|
shape = append(shape, uint64(dim))
|
||
|
}
|
||
|
|
||
|
ts = append(ts, torch{
|
||
|
storage: t.(*pytorch.Tensor).Source,
|
||
|
tensorBase: &tensorBase{
|
||
|
name: replacer.Replace(k.(string)),
|
||
|
shape: shape,
|
||
|
},
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return ts, nil
|
||
|
}
|
||
|
|
||
|
type torch struct {
|
||
|
storage pytorch.StorageInterface
|
||
|
*tensorBase
|
||
|
}
|
||
|
|
||
|
func (t torch) Clone() Tensor {
|
||
|
return torch{
|
||
|
storage: t.storage,
|
||
|
tensorBase: &tensorBase{
|
||
|
name: t.name,
|
||
|
shape: t.shape,
|
||
|
repacker: t.repacker,
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (pt torch) WriteTo(w io.Writer) (int64, error) {
|
||
|
return 0, nil
|
||
|
}
|