diff --git a/Llama3.java b/Llama3.java index 25a495f..85e9730 100755 --- a/Llama3.java +++ b/Llama3.java @@ -3,6 +3,7 @@ //PREVIEW //COMPILE_OPTIONS --add-modules=jdk.incubator.vector //RUNTIME_OPTIONS --add-modules=jdk.incubator.vector +//MAIN com.llama4j.Llama3 // Practical Llama 3 (and 3.1) inference in a single Java file // Author: Alfonso² Peterssen @@ -22,19 +23,29 @@ import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorOperators; import jdk.incubator.vector.VectorSpecies; +import sun.misc.Unsafe; +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.PrintStream; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; +import java.lang.reflect.Field; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; import java.util.*; import java.util.concurrent.TimeUnit; import java.util.function.IntConsumer; @@ -80,9 +91,11 @@ static void runInteractive(Llama model, Sampler sampler, Options options) { Llama.State state = null; List conversationTokens = new ArrayList<>(); ChatFormat chatFormat = new ChatFormat(model.tokenizer()); + int stateCacheSize = 0; conversationTokens.add(chatFormat.beginOfText); if (options.systemPrompt() != null) { conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); + stateCacheSize = conversationTokens.size(); } int startPosition = 0; Scanner in = new Scanner(System.in); @@ -97,9 +110,27 @@ static void runInteractive(Llama model, Sampler sampler, Options options) { state = model.createNewState(); } conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText))); + if (stateCacheSize == 0) { + stateCacheSize = conversationTokens.size(); + } conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + if (startPosition == 0 && options.readStateCache() && options.stateCachePath() != null) { + try (FileInputStream fis = new FileInputStream(options.stateCachePath().toFile()); + BufferedInputStream bis = new BufferedInputStream(fis)) { + System.out.printf("Read cached tokens in %s%n", options.stateCachePath()); + startPosition = state.deserialize(bis, model.configuration(), model.tokenizer(), conversationTokens, options.echo()); + } catch (IOException e) { + throw new RuntimeException("IO-exception while reading state-cache " + options.stateCachePath(), e); + } + + } + Path pathStateCache = (startPosition == 0 && options.writeStateCache() && options.stateCachePath() != null) ? 
options.stateCachePath() : null; Set stopTokens = chatFormat.getStopTokens(); - List responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> { + if (options.echo()) { + dumpStatistics(model, startPosition, stopTokens); + } + List responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), pathStateCache, stateCacheSize, + stopTokens, options.maxTokens(), sampler, options.echo(), token -> { if (options.stream()) { if (!model.tokenizer().isSpecialToken(token)) { System.out.print(model.tokenizer().decode(List.of(token))); @@ -129,16 +160,37 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) { Llama.State state = model.createNewState(); ChatFormat chatFormat = new ChatFormat(model.tokenizer()); + int stateCacheSize = 0; List promptTokens = new ArrayList<>(); promptTokens.add(chatFormat.beginOfText); if (options.systemPrompt() != null) { promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); + stateCacheSize = promptTokens.size(); } promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt()))); + if (stateCacheSize == 0) { + stateCacheSize = promptTokens.size(); + } promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + int startPosition = 0; + if (options.readStateCache() && options.stateCachePath() != null) { + try (FileInputStream fis = new FileInputStream(options.stateCachePath().toFile()); + BufferedInputStream bis = new BufferedInputStream(fis)) { + System.out.printf("Read cached tokens in %s%n", options.stateCachePath()); + startPosition = state.deserialize(bis, model.configuration(), model.tokenizer(), promptTokens, options.echo()); + } catch (IOException e) { + throw new RuntimeException("IO-exception while reading state-cache " + options.stateCachePath(), e); + } + + } + Path pathStateCache = (startPosition == 0 && options.writeStateCache() && options.stateCachePath() != null) ? 
options.stateCachePath() : null; Set stopTokens = chatFormat.getStopTokens(); - List responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), token -> { + if (options.echo()) { + dumpStatistics(model, startPosition, stopTokens); + } + List responseTokens = Llama.generateTokens(model, state, startPosition, promptTokens, pathStateCache, stateCacheSize, + stopTokens, options.maxTokens(), sampler, options.echo(), token -> { if (options.stream()) { if (!model.tokenizer().isSpecialToken(token)) { System.out.print(model.tokenizer().decode(List.of(token))); @@ -154,8 +206,23 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) { } } + private static void dumpStatistics(Llama model, int startPosition, Set stopTokens) { + var config = model.configuration(); + int numLayers = config.numberOfLayers; + int dim = config.dim; + int headSize = config.headSize; + int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; + int vocSize = model.configuration().vocabularySize; + System.out.printf("model: numLayers=%d, dim=%d, numHeads=%d, headSize=%d, kvDim=%d, vocSize=%d%n", + numLayers, dim, config.numberOfHeads, headSize, kvDim, vocSize); + System.out.printf("startPosition=%d, stopTokens=%s%n", startPosition, stopTokens); + } + record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive, - float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) { + float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, + Path stateCachePath, boolean readStateCache, boolean writeStateCache) { + + static final int DEFAULT_MAX_TOKENS = 512; Options { require(modelPath != null, "Missing argument: --model is required"); @@ -185,16 +252,19 @@ static void printUsage(PrintStream out) { out.println(" --temperature, -temp temperature in [0,inf], default 0.1"); out.println(" --top-p p value in top-p (nucleus) sampling in [0,1] default 0.95"); out.println(" --seed random seed, default System.nanoTime()"); - out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default 512"); + out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS); out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true"); out.println(" --echo print ALL tokens to stderr, if true, recommended to set --stream=false, default false"); + out.println(" --state-cache optional, path to state-cache file (.ggsc)"); + out.println(" --read-state-cache read state-cache file"); + out.println(" --write-state-cache write state-cache file"); out.println(); out.println("Examples:"); - out.println(" jbang Llama3.java --model llama3-8b-q4_0.gguf --prompt \"Tell me a joke\""); - out.println(" jbang Llama3.java --model llama3-8b-q4_0.gguf --system-prompt \"Reply concisely, in French\" --prompt \"Who was Marie Curie?\""); - out.println(" jbang Llama3.java --model llama3-8b-q4_0.gguf --system-prompt \"Answer concisely\" --chat"); - out.println(" jbang Llama3.java --model llama3-8b-q4_0.gguf --chat"); - out.println(" jbang Llama3.java --model llama3-8b-q4_0.gguf --prompt \"Print 5 emojis\" --stream=false"); + out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Tell me a joke\""); + out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Reply concisely, in French\" --prompt \"Who was 
Marie Curie?\""); + out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --system-prompt \"Answer concisely\" --chat"); + out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --chat"); + out.println(" jbang Llama3.java --model llama3.2-1b-q4_0.gguf --prompt \"Print 5 emojis\" --stream=false"); } static Options parseOptions(String[] args) { @@ -205,10 +275,13 @@ static Options parseOptions(String[] args) { Path modelPath = null; long seed = System.nanoTime(); // Keep max context length small for low-memory devices. - int maxTokens = 512; + int maxTokens = DEFAULT_MAX_TOKENS; boolean interactive = false; boolean stream = true; boolean echo = false; + Path stateCachePath = null; + boolean readStateCache = false; + boolean writeStateCache = false; for (int i = 0; i < args.length; i++) { String optionName = args[i]; @@ -241,18 +314,26 @@ static Options parseOptions(String[] args) { case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg); case "--stream" -> stream = Boolean.parseBoolean(nextArg); case "--echo" -> echo = Boolean.parseBoolean(nextArg); + case "--state-cache" -> stateCachePath = Paths.get(nextArg); + case "--read-state-cache" -> readStateCache = Boolean.parseBoolean(nextArg); + case "--write-state-cache" -> writeStateCache = Boolean.parseBoolean(nextArg); default -> require(false, "Unknown option: %s", optionName); } } } } - return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo); + return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo, + stateCachePath, readStateCache, writeStateCache); } } public static void main(String[] args) throws IOException { Options options = Options.parseOptions(args); - Llama model = ModelLoader.loadModel(options.modelPath(), options.maxTokens()); + Llama model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); + if (model == null) { + // No compatible preloaded model found, fallback to fully parse and load the specified file. + model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true); + } Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed()); if (options.interactive()) { runInteractive(model, sampler, options); @@ -272,10 +353,18 @@ final class GGUF { private int alignment; private int metadata_kv_count; // uint64_t private Map metadata; + + public Map getTensorInfos() { + return tensorInfos; + } + private Map tensorInfos; + private long tensorDataOffset; - private MemorySegment tensorData; // memory mapped tensor data - private Map tensorEntries; + + public long getTensorDataOffset() { + return tensorDataOffset; + } public Map getMetadata() { return metadata; @@ -286,10 +375,6 @@ public Map getMetadata() { private final ByteBuffer BB_4 = ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); private final ByteBuffer BB_8 = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN); - public Map getTensorEntries() { - return tensorEntries; - } - public static GGUF loadModel(Path modelPath) throws IOException { try (FileChannel fileChannel = FileChannel.open(modelPath); var ignored = Timer.log("Parse " + modelPath)) { @@ -374,16 +459,20 @@ private void loadModelImpl(FileChannel fileChannel) throws IOException { // should be padded to `ALIGNMENT` bytes. 
// uint8_t tensor_data[]; this.tensorDataOffset = fileChannel.position(); + } + + public static Map loadTensors(FileChannel fileChannel, long tensorDataOffset, Map tensorInfos) throws IOException { Arena arena = Arena.ofAuto(); - this.tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena); - this.tensorEntries = HashMap.newHashMap(tensorInfos.size()); - for (Map.Entry entry : tensorInfos.entrySet()) { - GGUF.GGUFTensorInfo ti = entry.getValue(); + MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena); + Map tensorEntries = HashMap.newHashMap(tensorInfos.size()); + for (Map.Entry entry : tensorInfos.entrySet()) { + GGUFTensorInfo ti = entry.getValue(); int numberOfElements = FloatTensor.numberOfElements(ti.dimensions()); int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements)); MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes); tensorEntries.put(ti.name(), new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment)); } + return tensorEntries; } public record GGUFTensorInfo(String name, int[] dimensions, GGMLType ggmlType, long offset) { @@ -652,20 +741,21 @@ private static Vocabulary loadVocabulary(Map metadata) { return new Vocabulary(tokens, null); } - public static Llama loadModel(Path ggufPath, int contextLength) throws IOException { + public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException { + GGUF gguf = GGUF.loadModel(ggufPath); + FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ); + return loadModel(ggufPath, fileChannel, gguf, contextLength, loadWeights); + } + + public static Llama loadModel(Path ggufPath, FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException { try (var ignored = Timer.log("Load LlaMa model")) { - GGUF gguf = GGUF.loadModel(ggufPath); Map metadata = gguf.getMetadata(); - Vocabulary vocabulary = loadVocabulary(metadata); Tokenizer tokenizer = createTokenizer(metadata, vocabulary); - int modelContextLength = (int) metadata.get("llama.context_length"); - if (contextLength < 0 || modelContextLength < contextLength) { - contextLength = modelContextLength; - } - + String modelName = ggufPath.getFileName().toString(); Llama.Configuration config = new Llama.Configuration( + modelName, (int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"), @@ -676,42 +766,52 @@ public static Llama loadModel(Path ggufPath, int contextLength) throws IOExcepti : (int) metadata.get("llama.attention.head_count"), vocabulary.size(), - contextLength, - false, + (int) metadata.get("llama.context_length"), (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) - ); - - boolean ropeScaling = "Meta-Llama-3.1".equals(metadata.get("general.basename")); - float scaleFactor = 8; - float loFreqFactor = 1; - float hiFreqFactor = 3; - int oldContextLength = 8192; - Pair ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta, - ropeScaling, scaleFactor, loFreqFactor, hiFreqFactor, oldContextLength); - float[] ropeFreqsReal = ropeFreqs.first(); - float[] ropeFreqsImag = ropeFreqs.second(); - - Map tensorEntries = gguf.getTensorEntries(); - Llama.Weights qw = new 
Llama.Weights( - loadQuantized(tensorEntries.get("token_embd.weight")), - loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 - loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 - toFloatBuffer(tensorEntries.get("output_norm.weight")), - FloatBuffer.wrap(ropeFreqsReal), - FloatBuffer.wrap(ropeFreqsImag), - loadQuantized(tensorEntries.get("output.weight")) - ); - - return new Llama(config, tokenizer, qw); - } + ).withContextLength(contextLength); + + Llama.Weights weights = null; + if (loadWeights) { + Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + weights = loadWeights(tensorEntries, config); + } + return new Llama(config, tokenizer, weights); + } + } + + static Llama.Weights loadWeights(Map tensorEntries, Llama.Configuration config) { + boolean ropeScaling = tensorEntries.containsKey("rope_freqs"); + float scaleFactor = 8; + float loFreqFactor = 1; + float hiFreqFactor = 3; + int oldContextLength = 8192; + Pair ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, config.headSize, config.ropeTheta, + ropeScaling, scaleFactor, loFreqFactor, hiFreqFactor, oldContextLength); + float[] ropeFreqsReal = ropeFreqs.first(); + float[] ropeFreqsImag = ropeFreqs.second(); + + GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); + Llama.Weights qw = new Llama.Weights( + loadQuantized(tokenEmbeddings), + loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 + loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 + loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 + toFloatBuffer(tensorEntries.get("output_norm.weight")), + FloatBuffer.wrap(ropeFreqsReal), + FloatBuffer.wrap(ropeFreqsImag), + // If "output.weight" is not present then the embedding weights are tied/shared with the decoder. + // This is commonly referred as "tie word embeddings". 
+ loadQuantized(tensorEntries.getOrDefault("output.weight", tokenEmbeddings)) + ); + + return qw; } private static Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { @@ -785,6 +885,7 @@ public State createNewState() { } public static final class Configuration { + public final String modelGGUFName; // model GGUF-name public final int dim; // transformer dimension public final int hiddenDim; // for ffn layers public final int numberOfLayers; // number of layers @@ -792,12 +893,19 @@ public static final class Configuration { public final int numberOfKeyValueHeads; // number of key/value heads (can be < query heads because of multiquery) public final int vocabularySize; // vocabulary size, usually 256 (byte-level) public final int contextLength; // max sequence length - public final boolean sharedWeights; public final float rmsNormEps; public final float ropeTheta; public final int headSize; - public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, boolean sharedWeights, float rmsNormEps, float ropeTheta) { + Configuration withContextLength(int newContextLength) { + if (newContextLength < 0) { + return this; // no change + } + return new Configuration(this.modelGGUFName, this.dim, this.hiddenDim, this.numberOfLayers, this.numberOfHeads, this.numberOfKeyValueHeads, this.vocabularySize, newContextLength, this.rmsNormEps, this.ropeTheta); + } + + public Configuration(String modelGGUFName, int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) { + this.modelGGUFName = modelGGUFName; this.dim = dim; this.hiddenDim = hiddenDim; this.numberOfLayers = numberOfLayers; @@ -805,7 +913,6 @@ public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHea this.numberOfKeyValueHeads = numberOfKeyValueHeads; this.vocabularySize = vocabularySize; this.contextLength = contextLength; - this.sharedWeights = sharedWeights; this.rmsNormEps = rmsNormEps; this.ropeTheta = ropeTheta; this.headSize = dim / numberOfHeads; @@ -854,6 +961,12 @@ public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, } public static final class State { + /** "GGSC": GGUF state cache */ + private static final int MAGIC_STATE_CACHE =0x47475343; + /** "Version 2 */ + private static final int MAGIC_STATE_VERSION =0x02; + + private final int kvDim; // current wave of activations public final FloatTensor x; // activation at current time stamp (dim,) @@ -883,10 +996,137 @@ public static final class State { this.v = ArrayFloatTensor.allocate(config.dim); this.att = ArrayFloatTensor.allocate(config.numberOfHeads, config.contextLength); this.logits = ArrayFloatTensor.allocate(config.vocabularySize); - int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; + this.kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new); this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new); } + + public void serialize(OutputStream os, Configuration config, List tokens) throws IOException { + if (!(keyCache[0] instanceof ArrayFloatTensor)) { + throw new UnsupportedOperationException("keyCache has unexpected type: " + 
keyCache.getClass()); + } + if (!(valueCache[0] instanceof ArrayFloatTensor)) { + throw new UnsupportedOperationException("valueCache has unexpected type: " + valueCache.getClass()); + } + byte[] bufName = config.modelGGUFName.getBytes(StandardCharsets.UTF_8); + int[] sizes = { 24, bufName.length, tokens.size() * 4, kvDim * 4}; + ByteBuffer bb = ByteBuffer.allocate(Arrays.stream(sizes).max().getAsInt()); + bb.putInt(0, MAGIC_STATE_CACHE); + bb.putInt(4, MAGIC_STATE_VERSION); + bb.putInt(8, bufName.length); + bb.putInt(12, tokens.size()); + bb.putInt(16, kvDim); + bb.putInt(20, config.numberOfLayers); + os.write(bb.array(), 0, 24); + os.write(bufName); + for (int i = 0; i < tokens.size(); i++) { + bb.putInt(4 * i, tokens.get(i).intValue()); + } + os.write(bb.array(), 0, 4 * tokens.size()); + for (int nLayer = 0; nLayer < config.numberOfLayers; nLayer++) { + for (int i = 0; i < tokens.size(); i++) { + for (int k = 0; k < kvDim; k++) { + bb.putFloat(4 * k, keyCache[nLayer].getFloat(k + i * kvDim)); + } + os.write(bb.array(), 0, 4 * kvDim); + for (int k = 0; k < kvDim; k++) { + bb.putFloat(4 * k, valueCache[nLayer].getFloat(k + i * kvDim)); + } + os.write(bb.array(), 0, 4 * kvDim); + } + } + } + + public int deserialize(InputStream is, Configuration config, Tokenizer tokenizer, List tokens, boolean echo) throws IOException { + ByteBuffer bb = ByteBuffer.allocate(24); + read(is, bb, 24); + check(bb, 0, MAGIC_STATE_CACHE, "MAGIC_STATE_CACHE"); + check(bb, 4, MAGIC_STATE_VERSION, "MAGIC_STATE_VERSION"); + int nameActualLength = bb.getInt(8); + int numCachedTokens = bb.getInt(12); + check(bb, 16, kvDim, "kvDim"); + check(bb, 20, config.numberOfLayers, "numberOfLayers"); + + int[] sizes = { nameActualLength, numCachedTokens * 4, kvDim * 4 }; + bb = ByteBuffer.allocate(Arrays.stream(sizes).max().getAsInt()); + + String nameExpected = config.modelGGUFName; + byte[] bufNameExpected = nameExpected.getBytes(StandardCharsets.UTF_8); + read(is, bb, nameActualLength); + String nameActual = new String(bb.array(), 0, nameActualLength); + if (!nameActual.equals(nameExpected)) { + throw new IllegalArgumentException(String.format("Invalid model-name in state-cache: expected='%s', actual='%s'", nameExpected, nameActual)); + } + read(is, bb, 4 * numCachedTokens); + int numTokensRead = 0; + final List cachedTokens = new ArrayList<>(); + for (int i = 0; i < numCachedTokens && i < tokens.size(); i++) { + final int actual = bb.getInt(4 * i); + cachedTokens.add(actual); + if (i != numTokensRead) { + continue; + } + int expected = tokens.get(i).intValue(); + if (actual == expected) { + numTokensRead++; + } else if (i < tokens.size() - 2) { + System.out.printf("Reused %d of %d tokens in cache-file, actual=%d ('%s'), expected=%d ('%s')%n", + numTokensRead, tokens.size(), + expected, tokenizer.decode(Collections.singletonList(expected)), actual, tokenizer.decode(Collections.singletonList(actual))); + } + } + if (echo) { + System.out.println("Current tokens: " + tokens); + System.out.println("Cached tokens: " + cachedTokens); + } + for (int nLayer = 0; nLayer < config.numberOfLayers; nLayer++) { + for (int i = 0; i < numCachedTokens; i++) { + read(is, bb, 4 * kvDim); + if (i < numTokensRead) { + for (int k = 0; k < kvDim; k++) { + keyCache[nLayer].setFloat(k + i * kvDim, bb.getFloat(4 * k)); + } + } + read(is, bb, 4 * kvDim); + if (i < numTokensRead) { + for (int k = 0; k < kvDim; k++) { + valueCache[nLayer].setFloat(k + i * kvDim, bb.getFloat(4 * k)); + } + } + } + } + return numTokensRead; + } + + static void 
check(ByteBuffer bb, int offset, int expected, String name) throws IOException { + final int actual = bb.getInt(offset); + if (actual != expected) { + throw new IOException(String.format("Unexpected value '%s': actual=0x%x, expected=0x%x", name, actual, expected)); + } + } + + static void check(ByteBuffer bb, int offset, int expected, String name, Tokenizer tokenizer) throws IOException { + final int actual = bb.getInt(offset); + if (actual != expected) { + throw new IOException(String.format("Unexpected value '%s': actual=%d ('%s'), expected=%d ('%s')", + name, actual, tokenizer.decode(Collections.singletonList(actual)), expected, tokenizer.decode(Collections.singletonList(expected)))); + } + } + + static void read(InputStream is, ByteBuffer bb, int size) throws IOException { + byte[] buf = bb.array(); + int offset = 0; + while (offset < size) { + int len = is.read(buf, offset, size - offset); + if (len == -1) { + break; + } + offset += len; + } + if (offset < size) { + throw new IOException(String.format("Unexpected end of stream: offset=%d, size=%d", offset, size)); + } + } } static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) { @@ -900,10 +1140,10 @@ static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index))); } - static FloatTensor forward(Llama model, Llama.State state, int token, int position) { + static FloatTensor forward(Llama model, State state, int token, int position) { // a few convenience variables - Llama.Configuration config = model.configuration(); - Llama.Weights weights = model.weights(); + Configuration config = model.configuration(); + Weights weights = model.weights(); int dim = config.dim; int headSize = config.headSize; int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; @@ -1035,6 +1275,8 @@ static FloatTensor forward(Llama model, Llama.State state, int token, int positi * @param state state of the model e.g. key/value caches ... this is mutated by this call * @param startPosition start prompt ingestion + inference at this position in the context e.g. useful if state was kept across calls (chained generation). 0 implies run with no previous context. * @param promptTokens prompt tokens to ingest, all the prompt tokens will be ingested, given there's enough capacity left in the context + * @param stateCachePath optional path of a state-cache file to be written + * @param stateCacheSize size of the state-cache (number of tokens to be cached) in case a state-cache is to be written * @param stopTokens set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion * @param maxTokens maximum number of tokens (can go up to {@link Configuration#contextLength context length} * if this value is negative or greater than {@link Configuration#contextLength context length} @@ -1043,7 +1285,8 @@ static FloatTensor forward(Llama model, Llama.State state, int token, int positi * @param onTokenGenerated callback, if non-null, it's called every time a token is inferred e.g. it's not called when ingesting prompt tokens * @return list of generated/inferred tokens, including the stop token, if any e.g. 
does not include any token from the prompt */ - public static List generateTokens(Llama model, Llama.State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + public static List generateTokens(Llama model, Llama.State state, int startPosition, List promptTokens, Path stateCachePath, int stateCacheSize, + Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) { long startNanos = System.nanoTime(); if (maxTokens < 0 || model.configuration().contextLength < maxTokens) { @@ -1055,6 +1298,16 @@ public static List generateTokens(Llama model, Llama.State state, int s int promptIndex = 0; for (int position = startPosition; position < maxTokens; ++position) { forward(model, state, token, position); + if (promptIndex == stateCacheSize && stateCachePath != null) { + System.out.println(String.format("Write %d of %d tokens into %s", stateCacheSize, promptTokens.size(), stateCachePath)); + final List tokenToCache = promptTokens.subList(0, stateCacheSize); + try (FileOutputStream fos = new FileOutputStream(stateCachePath.toFile()); + BufferedOutputStream bos = new BufferedOutputStream(fos)) { + state.serialize(bos, model.configuration, tokenToCache); + } catch (IOException e) { + throw new RuntimeException(String.format("IO-error while writing %s", stateCachePath), e); + } + } if (promptIndex < promptTokens.size()) { // Force-pick token from prompt. nextToken = promptTokens.get(promptIndex++); @@ -1427,11 +1680,35 @@ private static boolean isPowerOf2(int n) { * e.g. can represent a sequence of quantized floats. */ abstract class FloatTensor { - static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT.withOrder(ByteOrder.LITTLE_ENDIAN); - static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); - static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); + // static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT.withOrder(ByteOrder.LITTLE_ENDIAN); + // static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); + + + // The use of Unsafe in this file is a temporary workaround to support native-image. + static final Unsafe UNSAFE; + + static { + try { + Field f = Unsafe.class.getDeclaredField("theUnsafe"); + f.setAccessible(true); + UNSAFE = (Unsafe) f.get(null); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + static short readShort(MemorySegment memorySegment, long offset) { + // The MemorySegment.get* methods should be used instead. + return UNSAFE.getShort(memorySegment.address() + offset); + } + + static byte readByte(MemorySegment memorySegment, long offset) { + // The MemorySegment.get* methods should be used instead. + return UNSAFE.getByte(memorySegment.address() + offset); + } + // Preferred vector size for the fast multiplication routines. // (Apple Silicon) NEON only supports up-to 128bit vectors. static final VectorSpecies F_SPECIES = FloatVector.SPECIES_PREFERRED.vectorBitSize() == 128 ? 
FloatVector.SPECIES_128 : FloatVector.SPECIES_256; @@ -1626,13 +1903,13 @@ public float getFloat(int index) { assert 0 <= index && index < size; int blockIndex = index / GGMLType.Q4_0.getBlockSize(); int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize(); - float scale = Float.float16ToFloat(memorySegment.get(JAVA_SHORT_LE, blockOffset)); + float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); byte quant; int modIndex = index % GGMLType.Q4_0.getBlockSize(); if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) { - quant = (byte) (memorySegment.get(ValueLayout.JAVA_BYTE, blockOffset + Float16.BYTES + modIndex) & 0x0F); + quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F); } else { - quant = (byte) ((memorySegment.get(ValueLayout.JAVA_BYTE, blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F); + quant = (byte) ((readByte(memorySegment, blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F); } quant -= 8; return quant * scale; @@ -1641,13 +1918,13 @@ public float getFloat(int index) { @Override public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { if (FloatTensor.USE_VECTOR_API) { - return vectorDot(this, thisOffset, that, thatOffset, size); + return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); } else { return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); } } - private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { float result = 0f; int j = 0; @@ -1664,7 +1941,7 @@ private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, FloatTensor int blockOffset = (thisOffset + j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getTypeSize(); int upperBound = size / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getBlockSize(); for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) { - float wScaleValue = Float.float16ToFloat(thiz.memorySegment.get(JAVA_SHORT_LE, blockOffset)); + float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); var B_SPECIES = ByteVector.SPECIES_128; var wBytes = ByteVector.fromMemorySegment(B_SPECIES, thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); @@ -1737,8 +2014,8 @@ public float getFloat(int index) { int blockIndex = index / GGMLType.Q8_0.getBlockSize(); int withinBlockIndex = index % GGMLType.Q8_0.getBlockSize(); int blockOffset = blockIndex * GGMLType.Q8_0.getTypeSize(); - byte quant = memorySegment.get(ValueLayout.JAVA_BYTE, blockOffset + Float16.BYTES + withinBlockIndex); - float scale = Float.float16ToFloat(memorySegment.get(JAVA_SHORT_LE, blockOffset)); + byte quant = readByte(memorySegment, blockOffset + Float16.BYTES + withinBlockIndex); + float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); return quant * scale; } @@ -1747,13 +2024,13 @@ public float getFloat(int index) { @Override public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { if (FloatTensor.USE_VECTOR_API) { - return vectorDot(this, thisOffset, that, thatOffset, size); + return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); } else { return FloatTensor.scalarDot(this, thisOffset, that, 
thatOffset, size); } } - private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { + private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { float result = 0f; int j = 0; @@ -1770,7 +2047,7 @@ private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, FloatTensor int blockOffset = (thisOffset + j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getTypeSize(); int upperBound = size / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getBlockSize(); for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) { - float wScaleValue = Float.float16ToFloat(thiz.memorySegment.get(JAVA_SHORT_LE, blockOffset)); + float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); if (F_SPECIES.vectorBitSize() == 256) { var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); @@ -1855,7 +2132,7 @@ public FloatVector getFloatVector(VectorSpecies species, int index) { final class RoPE { public static Pair precomputeFreqsCis(int contextLength, int headSize, double theta, - boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) { + boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) { assert headSize % 2 == 0; float[] cr = new float[contextLength * (headSize / 2)]; float[] ci = new float[contextLength * (headSize / 2)]; @@ -2029,13 +2306,14 @@ public int sampleToken(FloatTensor logits) { */ class ChatFormat { - protected final Tokenizer tokenizer; - protected final int beginOfText; - protected final int endHeader; - protected final int startHeader; - protected final int endOfTurn; - protected final int endOfText; - protected final int endOfMessage; + final Tokenizer tokenizer; + final int beginOfText; + final int endHeader; + final int startHeader; + final int endOfTurn; + final int endOfText; + final int endOfMessage; + final Set stopTokens; public ChatFormat(Tokenizer tokenizer) { this.tokenizer = tokenizer; @@ -2046,6 +2324,7 @@ public ChatFormat(Tokenizer tokenizer) { this.endOfTurn = specialTokens.get("<|eot_id|>"); this.endOfText = specialTokens.get("<|end_of_text|>"); this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.stopTokens = Set.of(endOfText, endOfTurn); } public Tokenizer getTokenizer() { @@ -2053,7 +2332,7 @@ public Tokenizer getTokenizer() { } public Set getStopTokens() { - return Set.of(endOfText, endOfTurn); + return stopTokens; } public List encodeHeader(ChatFormat.Message message) { @@ -2100,4 +2379,65 @@ public String toString() { } } +/** + * Support for AOT preloading of GGUF metadata with GraalVM's Native Image. + * + *

+ * To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf} + * to the native-image builder command. At runtime, the preloaded model will be used + * iff the specified and preloaded file names (base name) match. + */ +final class AOT { + record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map tensorInfos) {} + private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF")); + + private static PartialModel preLoadGGUF(String modelPath) { + if (modelPath == null || modelPath.isEmpty()) { + return null; + } + try { + Path path = Path.of(modelPath); + if (!Files.exists(path) || !Files.isRegularFile(path)) { + throw new IllegalArgumentException("Cannot pre-load model: " + path); + } + GGUF gguf = GGUF.loadModel(path); + try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) { + return new PartialModel( + path.getFileName().toString(), + ModelLoader.loadModel(path, fileChannel, gguf, Llama3.Options.DEFAULT_MAX_TOKENS, false), + gguf.getTensorDataOffset(), + gguf.getTensorInfos() + ); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * Tries to reuse a compatible AOT preloaded model. + * The file name (base name) must match with the preloaded file name. + * No checksum/hash is checked for performance reasons. + */ + public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException { + AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF; + if (preLoaded == null) { + return null; // no pre-loaded model stored + } + String optionsModel = modelPath.getFileName().toString(); + String preLoadedModel = preLoaded.modelFileName(); + if (!Objects.equals(optionsModel, preLoadedModel)) { + // Preloaded and specified model file names didn't match. + return null; + } + Llama baseModel = preLoaded.model(); + try (var timer = Timer.log("Load tensors from pre-loaded model"); + var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) { + // Load only the tensors (mmap slices). + Map tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos()); + Llama.Weights weights = ModelLoader.loadWeights(tensorEntries, baseModel.configuration()); + return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights); + } + } +} diff --git a/Makefile b/Makefile index 7e3c025..59301d2 100644 --- a/Makefile +++ b/Makefile @@ -2,15 +2,28 @@ ifdef JAVA_HOME JAVAC ?= ${JAVA_HOME}/bin/javac JAVA ?= ${JAVA_HOME}/bin/java JAR ?= ${JAVA_HOME}/bin/jar + NATIVE_IMAGE ?= ${JAVA_HOME}/bin/native-image endif JAVAC ?= javac JAVA ?= java JAR ?= jar +NATIVE_IMAGE ?= native-image -JAVA_COMPILE_OPTIONS = --enable-preview -source 21 -g --add-modules jdk.incubator.vector +JAVA_MAJOR_VERSION := $(shell $(JAVA) -version 2>&1 | head -n 1 | cut -d'"' -f2 | cut -d'.' 
-f1) + +JAVA_COMPILE_OPTIONS = --enable-preview -source $(JAVA_MAJOR_VERSION) -g --add-modules jdk.incubator.vector JAVA_RUNTIME_OPTIONS = --enable-preview --add-modules jdk.incubator.vector +ifeq ($(OS),Windows_NT) + EXE := .exe +else + EXE := +endif + +# Define the executable name +NATIVE_FILE := llama3$(EXE) + JAVA_MAIN_CLASS = com.llama4j.Llama3 JAR_FILE = llama3.jar @@ -37,7 +50,7 @@ run-jar-command: # Clean the target directory clean: rm -rf ./target - rm $(JAR_FILE) + rm $(JAR_FILE) $(NATIVE_FILE) # Compile the Java source files target/classes/com/llama4j/%.class: %.java @@ -47,10 +60,26 @@ target/classes/com/llama4j/%.class: %.java target/classes: mkdir -p target/classes +$(NATIVE_FILE): jar + $(NATIVE_IMAGE) \ + -H:+UnlockExperimentalVMOptions \ + -H:+VectorAPISupport \ + -H:+ForeignAPISupport \ + -O3 \ + -march=native \ + --enable-preview \ + --add-modules jdk.incubator.vector \ + --initialize-at-build-time='com.llama4j.AOT,com.llama4j.FloatTensor,com.llama4j.' \ + -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 \ + -Dllama.PreloadGGUF=$(PRELOAD_GGUF) \ + -jar $(JAR_FILE) \ + -o $(NATIVE_FILE) + # Make the target directory a dependency of the Java class files $(JAVA_CLASSES): target/classes compile: target/classes default: jar +native: $(NATIVE_FILE) .PHONY: compile clean jar run-command run-jar-command .SUFFIXES: .java .class .jar diff --git a/README.md b/README.md index 38bbc79..072f8af 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # Llama3.java -Practical [Llama 3](https://github.com/meta-llama/llama3) and [3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1) inference implemented in a single Java file. +Practical [Llama 3](https://github.com/meta-llama/llama3), [3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1) and [3.2](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) inference implemented in a single Java file.


This project is the successor of [llama2.java](https://github.com/mukel/llama2.java) @@ -17,41 +17,50 @@ Besides the educational value, this project will be used to test and tune compil - [GGUF format](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) parser - Llama 3 tokenizer based on [minbpe](https://github.com/karpathy/minbpe) - Llama 3 inference with Grouped-Query Attention - - Support Llama 3.1 (ad-hoc RoPE scaling) + - Support Llama 3.1 (ad-hoc RoPE scaling) and 3.2 (tie word embeddings) - Support for Q8_0 and Q4_0 quantizations - Fast matrix-vector multiplication routines for quantized tensors using Java's [Vector API](https://openjdk.org/jeps/469) - Simple CLI with `--chat` and `--instruct` modes. + - Support for caching of states (e.g. the system-prompt) Here's the interactive `--chat` mode in action:

+*(demo of the interactive `--chat` mode)*

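The state-cache support listed above stores the key/value cache of an already-ingested prompt in a binary `.ggsc` file, written by `Llama.State#serialize` in the patch. As a rough illustration only (this helper is **not** part of the patch, and the class name is made up), the fixed 24-byte header can be inspected with plain `DataInputStream` reads, assuming the layout used by `serialize`: six big-endian `int`s, followed by the model file name, the cached token ids, and the per-layer/per-token key and value rows.

```java
// Hypothetical helper, not part of Llama3.java: dumps the header of a .ggsc state-cache file.
// Layout assumed from Llama.State#serialize: six big-endian ints, then the model file name,
// then the cached token ids, then per layer and token kvDim key floats and kvDim value floats.
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

class GgscHeaderDump {
    public static void main(String[] args) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(args[0]))) {
            int magic = in.readInt();      // 0x47475343, ASCII "GGSC"
            int version = in.readInt();    // 0x02
            int nameLength = in.readInt(); // length of the model's GGUF file name in bytes
            int numTokens = in.readInt();  // number of cached (prompt) tokens
            int kvDim = in.readInt();      // key/value dimension per token
            int numLayers = in.readInt();  // number of transformer layers
            byte[] name = new byte[nameLength];
            in.readFully(name);
            System.out.printf("magic=0x%08X version=%d model=%s tokens=%d kvDim=%d layers=%d%n",
                    magic, version, new String(name, StandardCharsets.UTF_8), numTokens, kvDim, numLayers);
        }
    }
}
```

The `deserialize` method in the patch validates the same fields (magic, version, `kvDim`, layer count) plus the model name, and only reuses the longest matching prefix of the current prompt.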
## Setup -Download pure `Q4_0` and (optionally) `Q8_0` quantized .gguf files from: +Download pure `Q4_0` and (optionally) `Q8_0` quantized .gguf files from: + - https://huggingface.co/mukel/Llama-3.2-1B-Instruct-GGUF + - https://huggingface.co/mukel/Llama-3.2-3B-Instruct-GGUF - https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF - https://huggingface.co/mukel/Meta-Llama-3-8B-Instruct-GGUF -The `~4.3GB` pure `Q4_0` quantized model is recommended, please be gentle with [huggingface.co](https://huggingface.co) servers: +The pure `Q4_0` quantized models are recommended, except for the very small models (1B), please be gentle with [huggingface.co](https://huggingface.co) servers: ``` -# Llama 3.1 +# Llama 3.2 (3B) +curl -L -O https://huggingface.co/mukel/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_0.gguf + +# Llama 3.2 (1B) +curl -L -O https://huggingface.co/mukel/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q8_0.gguf + +# Llama 3.1 (8B) curl -L -O https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_0.gguf -# Llama 3 +# Llama 3 (8B) curl -L -O https://huggingface.co/mukel/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf -# Optionally download the Q8_0 quantized model ~8GB -# curl -L -O https://huggingface.co/mukel/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q8_0.gg +# Optionally download the Q8_0 quantized models +# curl -L -O https://huggingface.co/mukel/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q8_0.gguf # curl -L -O https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf ``` #### Optional: quantize to pure `Q4_0` manually -In the wild, `Q8_0` quantizations are fine, but `Q4_0` quantizations are rarely pure e.g. the `output.weights` tensor is quantized with `Q6_K`, instead of `Q4_0`. +In the wild, `Q8_0` quantizations are fine, but `Q4_0` quantizations are rarely pure e.g. the `token_embd.weights`/`output.weights` tensor are quantized with `Q6_K`, instead of `Q4_0`. 
A **pure** `Q4_0` quantization can be generated from a high precision (F32, F16, BFLOAT16) .gguf source -with the `quantize` utility from [llama.cpp](https://github.com/ggerganov/llama.cpp) as follows: +with the `llama-quantize` utility from [llama.cpp](https://github.com/ggerganov/llama.cpp) as follows: ```bash ./llama-quantize --pure ./Meta-Llama-3-8B-Instruct-F32.gguf ./Meta-Llama-3-8B-Instruct-Q4_0.gguf Q4_0 @@ -74,7 +83,7 @@ chmod +x Llama3.java ## Run from source ```bash -java --enable-preview --source 21 --add-modules jdk.incubator.vector LLama3.java -i --model Meta-Llama-3-8B-Instruct-Q4_0.gguf +java --enable-preview --source 21 --add-modules jdk.incubator.vector Llama3.java -i --model Meta-Llama-3-8B-Instruct-Q4_0.gguf ``` #### Optional: Makefile + manually build and run @@ -90,64 +99,88 @@ Run the resulting `llama3.jar` as follows: java --enable-preview --add-modules jdk.incubator.vector -jar llama3.jar --help ``` +### GraalVM Native Image + +Compile to native via `make` (recommended): + +```bash +make native +``` +Or directly: + +```bash +native-image -H:+UnlockExperimentalVMOptions -H:+VectorAPISupport -H:+ForeignAPISupport -O3 -march=native --enable-preview --add-modules jdk.incubator.vector --initialize-at-build-time=com.llama4j.FloatTensor -Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 -jar llama3.jar -o llama3 +``` + +Run as Native Image: + +```bash +./llama3 --model Llama-3.2-1B-Instruct-Q8_0 --chat +``` + +### AOT model preloading + +`Llama3.java` supports AOT model preloading, enabling **0-overhead, instant inference, with minimal TTFT (time-to-first-token)**. + +To AOT pre-load a GGUF model: +```bash +PRELOAD_GGUF=/path/to/model.gguf make native +``` + +A specialized, larger binary will be generated, with no parsing overhead for that particular model. +It can still run other models, although incurring the usual parsing overhead. + +### State cache +Save the internal states of a system prompt. + +```bash +java --enable-preview --source 21 --add-modules jdk.incubator.vector Llama3.java --model Llama-3.2-3B-Instruct-Q4_0.gguf --write-state-cache true --state-cache prompt_cucko.ggsc -sp "You are a helpful cuckoo clock at the floor wall in front of the living-room. You like to talk about Java and its bytecode." -p "Hi!" +``` + +Read the stored states of the system prompt. + +```bash +java --enable-preview --source 21 --add-modules jdk.incubator.vector Llama3.java --model Llama-3.2-3B-Instruct-Q4_0.gguf --read-state-cache true --state-cache prompt_cucko.ggsc -sp "You are a helpful cuckoo clock at the floor wall in front of the living-room. You like to talk about Java and its bytecode." -i +``` + ## Performance -**Important Note** -On GraalVM, please note that the Graal compiler doesn't support the Vector API yet, run with `-Dllama.VectorAPI=false`, but expect sub-optimal performance. -Vanilla OpenJDK 21+ is recommended for now, which supports the Vector API. +GraalVM now supports more [Vector API](https://openjdk.org/jeps/469) operations. To give it a try, you need GraalVM for JDK 24 – get the EA builds from [`oracle-graalvm-ea-builds`](https://github.com/graalvm/oracle-graalvm-ea-builds) or sdkman: `sdk install java 24.ea.15-graal`. -### llama.cpp +#### llama.cpp -Vanilla `llama.cpp` built with `make -j 20`. +Vanilla `llama.cpp` built with `make`. 
```bash -./main --version -version: 2879 (4f026363) -built with cc (GCC) 13.2.1 20230801 for x86_64-pc-linux-gnu +./llama-cli --version +version: 3862 (3f1ae2e3) +built with cc (GCC) 14.2.1 20240805 for x86_64-pc-linux-gnu ``` Executed as follows: ```bash -./main -m ../Meta-Llama-3-8B-Instruct-Q4_0.gguf \ - -n 512 \ - -s 42 \ - -p "<|start_of_header_id|>user<|end_of_header_id|>Why is the sky blue?<|eot_id|><|start_of_header_id|>assistant<|end_of_header_id|>\n\n" \ - --interactive-specials +./llama-bench -m Llama-3.2-1B-Instruct-Q4_0.gguf -p 0 -n 128 ``` -Collected the **"eval time"** metric in tokens\s. -### Llama3.java -Running on OpenJDK 21.0.2. +#### Llama3.java ```bash -jbang Llama3.java \ - --model ./Meta-Llama-3-8B-Instruct-Q4_0.gguf \ - --max-tokens 512 \ +taskset -c 0-15 ./llama3 \ + --model ./Llama-3.2-1B-Instruct-Q4_0.gguf \ + --max-tokens 128 \ --seed 42 \ --stream false \ --prompt "Why is the sky blue?" ``` -### Results - -#### Notebook Intel 13900H 6pC+8eC/20T 64GB (5200) Linux 6.6.26 -| Model | tokens/s | Implementation | -|----------------------------------|----------|------------------| -| Llama-3-8B-Instruct-Q4_0.gguf | 7.53 | llama.cpp | -| Llama-3-8B-Instruct-Q4_0.gguf | 6.95 | llama3.java | -| Llama-3-8B-Instruct-Q8_0.gguf | 5.16 | llama.cpp | -| Llama-3-8B-Instruct-Q8_0.gguf | 4.02 | llama3.java | - -#### Workstation AMD 3950X 16C/32T 64GB (3200) Linux 6.6.25 +Hardware specs: 2019 AMD Ryzen 3950X 16C/32T 64GB (3800) Linux 6.6.47. **Notes** -*Running on a single CCD e.g. `taskset -c 0-15 jbang Llama3.java ...` since inference is constrained by memory bandwidth.* - -| Model | tokens/s | Implementation | -|----------------------------------|----------|------------------| -| Llama-3-8B-Instruct-Q4_0.gguf | 9.26 | llama.cpp | -| Llama-3-8B-Instruct-Q4_0.gguf | 8.03 | llama3.java | -| Llama-3-8B-Instruct-Q8_0.gguf | 5.79 | llama.cpp | -| Llama-3-8B-Instruct-Q8_0.gguf | 4.92 | llama3.java | +*Running on a single CCD e.g. `taskset -c 0-15 ./llama3 ...` since inference is constrained by memory bandwidth.* + +### Results +

+*(benchmark results charts)*

## License