scripts : Use mmap in torch load (#4202)

* Use mmap in torch load, prefer .bin files when loading

* Revert .bin > .safetensors preference
This commit is contained in:
Galunid 2023-11-25 22:45:02 +01:00 committed by GitHub
parent f837c3a992
commit 1ddb52ec38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -59,7 +59,7 @@ class Model:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else:
ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
with ctx as model_part:
for name in model_part.keys():