diff --git a/flake.nix b/flake.nix index 7723357af..433d3d942 100644 --- a/flake.nix +++ b/flake.nix @@ -35,6 +35,20 @@ ); pkgs = import nixpkgs { inherit system; }; nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ]; + cudatoolkit_joined = with pkgs; symlinkJoin { + # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit + # see https://github.com/NixOS/nixpkgs/issues/224291 + # copied from jaxlib + name = "${cudaPackages.cudatoolkit.name}-merged"; + paths = [ + cudaPackages.cudatoolkit.lib + cudaPackages.cudatoolkit.out + ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [ + # for some reason some of the required libs are in the targets/x86_64-linux + # directory; not sure why but this works around it + "${cudaPackages.cudatoolkit}/targets/${system}" + ]; + }; llama-python = pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]); postPatch = '' @@ -70,6 +84,13 @@ "-DLLAMA_CLBLAST=ON" ]; }; + packages.cuda = pkgs.stdenv.mkDerivation { + inherit name src meta postPatch nativeBuildInputs postInstall; + buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ]; + cmakeFlags = cmakeFlags ++ [ + "-DLLAMA_CUBLAS=ON" + ]; + }; packages.rocm = pkgs.stdenv.mkDerivation { inherit name src meta postPatch nativeBuildInputs postInstall; buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];