Parser for safetensors

There is a new format for storing the weights of neural networks, advocated by Hugging Face, called safetensors. I encountered it when I wanted to get the phi-2 model into Transformers.jl (see the pull request "Add Phi model", Pull Request #168 in chengchingwen/Transformers.jl on GitHub). I have written a simple, non-performant loader (it does too much seeking at the moment) and am posting it here, as I have not made a proper repo. JSON3 would be much nicer to use, but I am not familiar with it. The format is described in the Safetensors documentation.
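As a rough illustration of the layout the loader below relies on: a safetensors file starts with an 8-byte little-endian unsigned integer giving the length of the header, followed by the JSON header itself, followed by the raw tensor bytes. Here is a minimal sketch using only Base and an in-memory buffer. The tensor name `w` and its values are made up, and the machine is assumed to be little-endian.

```julia
# Build a tiny safetensors-style byte stream in memory and read it back.
# The tensor name "w" and the data are hypothetical, for illustration only.
data = Float32[1, 2, 3, 4, 5, 6]   # a 2x3 tensor, stored row-major
header = """{"w":{"dtype":"F32","shape":[2,3],"data_offsets":[0,24]}}"""

io = IOBuffer()
write(io, htol(UInt64(ncodeunits(header))))  # 8 bytes: header length, little-endian
write(io, header)                            # the JSON header
write(io, data)                              # raw tensor bytes (native byte order)
seekstart(io)

n = Int(ltoh(read(io, UInt64)))   # length of the header in bytes
json = String(read(io, n))        # header text, ready to be passed to a JSON parser
payload = read(io)                # everything after the header is tensor data

@show n json length(payload)
```

The `data_offsets` in the header are relative to the first byte after the header, which is why the loader adds `header_length` to them before seeking.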

using JSON

"""
	_gettype(s, name)

	Julia type of the tensor from the dtype string `s`; `name` is used only in error messages
"""
function _gettype(s::AbstractString, name="")
	s == "F16" && return(Float16)
	s == "F32" && return(Float32)
	s == "F64" && return(Float64)
	s == "BOOL" && return(Bool)
	s == "U8" && return(UInt8)
	s == "I8" && return(Int8)
	s == "I16" && return(Int16)
	s == "I32" && return(Int32)
	s == "I64" && return(Int64)
	s == "BF16" && error("BFloat16 is not supported")
	name = isempty(name) ? name : " of the tensor "*name
	error("unknown type $(s)", name)
end

_byteoftype(::Type{T}) where {T<:Union{Bool, UInt8, Int8}} = 1
_byteoftype(::Type{T}) where {T<:Union{Int16, Float16}} = 2
_byteoftype(::Type{T}) where {T<:Union{Int32, Float32}} = 4
_byteoftype(::Type{T}) where {T<:Union{Int64, Float64}} = 8

"""
	readtensor!(fio::IO, header::Dict, name::String, header_length; seek_to_start = true)
	readtensor!(fio::IO, T, shape, start, stop, name=""; seek_to_start = true)

	reads tensor `name` from the file `fio`. 
	`seek_to_start = true` means that `seek(fio, start)` will be called to ensure that reading 
	starts from the correct position 
"""
function readtensor!(fio::IO, header::Dict, name::String, header_length; seek_to_start = true)
	entry = header[name]
	T = _gettype(entry["dtype"], name)
	# offsets in the header are relative to the end of the header
	start = Int(entry["data_offsets"][1]) + header_length
	stop = Int(entry["data_offsets"][2]) + header_length
	shape = tuple(Int.(entry["shape"])...)
	readtensor!(fio, T, shape, start, stop, name; seek_to_start)
end

function readtensor!(fio::IO, T::Type, shape::NTuple{N,<:Integer}, start::Integer, stop::Integer, name=""; seek_to_start = true) where {N}
	seek_to_start && seek(fio, start)
	n = stop - start
	if _byteoftype(T)*prod(shape) != n
		s = isempty(name) ? "" : " of tensor "*name
		error("length of the stored data", s, " does not correspond to the shape of the tensor")
	end
	x = Vector{T}(undef, prod(shape))
	read!(fio, x)
	# data are stored row-major, Julia arrays are column-major
	x = reshape(x, reverse(shape))
	if length(shape) == 2
		x = transpose(x)
	end
	if length(shape) > 2
		@warn "higher dimensional tensor $(name) untested"
	end
	return(x)
end

function names_without_metadata(header)
	filter(s -> s != "__metadata__", collect(keys(header)))
end

"""
	starts_of_tensors(header)

	return a sorted list of pairs (name_of_tensor, start)
"""
function starts_of_tensors(header)
	ks = names_without_metadata(header)
	starts = map(ks) do k 
		k => Int(header[k]["data_offsets"][1])
	end
	sort!(starts, lt = (i,j) -> i[2] < j[2])
end

"""
	is_continuous(header, starts = starts_of_tensors(header))

	return true if tensors in header are stored back to back without gaps and can be read sequentially (which they should)
"""
function is_continuous(header, starts = starts_of_tensors(header))
	i = 0 
	for (k, start) in starts 
		start != i && return(false)
		i = Int(header[k]["data_offsets"][2])
	end
	return(true)
end

"""
	header, header_length = load_header(fio::IO)

	loads the header of a stream containing safetensors; `header_length` includes
	the 8 bytes that store the size of the header
"""
function load_header(fio::IO)
	seek(fio, 0)
	n = Int(ltoh(read(fio, UInt64))) # first read the length of the header (little-endian UInt64)
	s = read(fio, n) # then read the header
	header = JSON.parse(String(s))
	return(header, 8 + n)
end

function load_tensors_scattered(fio::IO, header, tensors, header_length; seek_to_start = true)
	Dict(map(k -> k => readtensor!(fio, header, k, header_length; seek_to_start), tensors))
end

function load_tensors(filename::AbstractString)
	open(filename,"r") do fio
		header, header_length = load_header(fio)
		starts = starts_of_tensors(header)
		# seeking is needed only if the tensors are not stored back to back
		seek_to_start = !is_continuous(header, starts)
		tensors = first.(starts)
		load_tensors_scattered(fio, header, tensors, header_length; seek_to_start)
	end
end

filename = "Downloads/model-00002-of-00002.safetensors"
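The reshape-and-transpose step in `readtensor!` is there because safetensors stores data in row-major (C) order while Julia arrays are column-major. A minimal sketch of how the same idea could be extended to higher-dimensional tensors with `permutedims` follows. Note that `from_rowmajor` is a hypothetical helper, not part of the loader above, and unlike `transpose` it copies the data.

```julia
# Convert a flat row-major buffer into a Julia (column-major) array of the given shape.
# `from_rowmajor` is a hypothetical helper written for this illustration.
function from_rowmajor(x::AbstractVector, shape::NTuple{N,Int}) where {N}
    length(x) == prod(shape) || error("length of data does not match shape")
    # Reshaping with reversed dimensions yields the tensor with its axes in reverse order;
    # permuting the axes back gives the tensor in its original shape.
    permutedims(reshape(x, reverse(shape)), ntuple(i -> N - i + 1, N))
end

flat = collect(1:6)                  # row-major data of the matrix [1 2 3; 4 5 6]
A = from_rowmajor(flat, (2, 3))
@assert A == [1 2 3; 4 5 6]

flat3 = collect(1:24)                # row-major data of a 2x3x4 tensor
B = from_rowmajor(flat3, (2, 3, 4))
@assert size(B) == (2, 3, 4)
@assert B[1, 1, 2] == 2              # the last index varies fastest in row-major order
@assert B[2, 1, 1] == 13             # the first index has stride 3*4 = 12
```

For the 2-dimensional case this gives the same values as `transpose(reshape(x, reverse(shape)))`, just materialized as a plain `Array`.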


I had a small bug in the above code. It is now tested on the parameters of the llama7b model, for which I was able to obtain the parameters stored both as pickle and as safetensors.
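For anyone checking the loader on their own files: the sequential-read path relies on the offsets in the header being contiguous. Below is a minimal sketch of that check on a hand-written header. The tensor names, the metadata entry and the offsets are made up for illustration, and `contiguous` is a hypothetical stand-in for `is_continuous`.

```julia
# Hand-written header in the shape a JSON parser would return; all values are made up.
header = Dict(
    "__metadata__" => Dict("format" => "pt"),
    "b" => Dict("dtype" => "F32", "shape" => [2], "data_offsets" => [8, 16]),
    "a" => Dict("dtype" => "F32", "shape" => [2], "data_offsets" => [0, 8]),
)

# (start, stop) pairs of all real tensors, ordered by start offset
offsets = sort([Tuple(v["data_offsets"]) for (k, v) in header if k != "__metadata__"])

# Contiguous means: the first tensor starts at 0 and each tensor starts where the previous one ended.
function contiguous(offsets)
    pos = 0
    for (start, stop) in offsets
        start == pos || return false
        pos = stop
    end
    return true
end

@assert contiguous(offsets)
@assert !contiguous([(0, 8), (12, 16)])   # a 4-byte gap would force a seek
```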

This would be great to have as a package, let me know if you need a home for it.

A home would be nice. For the moment I have put the code into Transformers.jl, so that it can be used immediately, but I was thinking that a package would be nice to have.

Where would you like to have it?

I was thinking something in FluxML maybe? Thus far nobody else has asked for SafeTensors outside of Transformers.jl use cases IIRC, but if they do we could pull the relevant code out into a standalone repo.

I might have access to that org. I will try to polish the package.




May I ask you to create an empty repository for the safetensors loader, named SafeTensors.jl, and invite me to it? I have finally created the package and, importantly, basic tests checking that things are loaded correctly from files written by Python's safetensors.

I can upload it there.

Thanks a lot