From 9d25ca1d21768cd36926bbbc1aa7eaaa70f87646 Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Sun, 6 Apr 2025 01:06:59 -0700 Subject: [PATCH] A forward pass Run forward passes with dummy codes. Output tensor shapes (raw audio samples) seem to match expected shape given number of input frames. Attempts with Orpheus to be done soon. The gguf used in this commit is at: https://huggingface.co/jamorphy/snac-fwd-pass-devel-gguf --- convert_hf_to_gguf.py | 2 +- examples/tts/orpheus-tts.cpp | 343 +++++++--------------- ggml/src/ggml-cpu/ggml-cpu.c | 1 + include/llama.h | 2 + src/llama-context.cpp | 36 ++- src/llama-context.h | 2 + src/llama-model.cpp | 551 ++++++++++++++++------------------- src/llama-model.h | 4 +- 8 files changed, 396 insertions(+), 545 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 01ec22aa3cc28..093e769e338f3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2494,7 +2494,7 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vocab_size (4096) # TODO: Fix + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"]) self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", self.hparams["codebook_dim"]) self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) # 1024 diff --git a/examples/tts/orpheus-tts.cpp b/examples/tts/orpheus-tts.cpp index 45595e9552fc0..a7f0e16dfa296 100644 --- a/examples/tts/orpheus-tts.cpp +++ b/examples/tts/orpheus-tts.cpp @@ -1,6 +1,5 @@ #include "common.h" #include "llama.h" -#include "llama-impl.h" #include "log.h" #include "arg.h" #include "sampling.h" @@ -19,148 +18,30 @@ #include #include -std::vector redistribute_codes(const std::vector& raw_codes) { - std::vector snac_codes; - for (size_t i = 0; i < raw_codes.size(); i += 7) { - // Ensure we have a full frame (7 codes) - if (i + 6 >= raw_codes.size()) break; - - // Frame offsets (per notebook) - snac_codes.push_back(raw_codes[i]); // Codebook 0 (no offset) - snac_codes.push_back(raw_codes[i+1] - 4096); // Codebook 1 - snac_codes.push_back(raw_codes[i+2] - 8192); // Codebook 2 - snac_codes.push_back(raw_codes[i+3] - 12288); // Codebook 2 - snac_codes.push_back(raw_codes[i+4] - 16384); // Codebook 1 - snac_codes.push_back(raw_codes[i+5] - 20480); // Codebook 2 - snac_codes.push_back(raw_codes[i+6] - 24576); // Codebook 2 - } - return snac_codes; -} - -static std::vector embd_to_audio( - const float * embd, - const int n_codes, - const int n_embd, - const int n_thread); -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate); -static void fill_hann_window(int length, bool periodic, float * output); -static void irfft(int n, const float * inp_cplx, float * out_real); -static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output); - -static void print_usage(int /*argc*/, char **argv) { - LOG("\nexample usage:\n"); - LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]); - LOG("\n"); -} - -static void prompt_add(std::vector &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) { - auto tmp = common_tokenize(vocab, txt, add_special, parse_special); - prompt.insert(prompt.end(), tmp.begin(), tmp.end()); -} - - -// // Include embd_to_audio and save_wav16 from tts.cpp (for now) -static std::vector embd_to_audio( - const float * embd, - const int n_codes, - const int n_embd, - const int n_thread) { - const int n_fft = 1280; - const int n_hop = 320; - const int n_win = 1280; - const int n_pad = (n_win - n_hop)/2; - const int n_out = (n_codes - 1)*n_hop + n_win; - - std::vector hann(n_fft); - fill_hann_window(hann.size(), true, hann.data()); - - int n_spec = n_embd*n_codes; - - std::vector E (n_spec); - std::vector S (n_spec); - std::vector ST(n_spec); - - for (int l = 0; l < n_codes; ++l) { - for (int k = 0; k < n_embd; ++k) { - E[k*n_codes + l] = embd[l*n_embd + k]; - } - } - - for (int k = 0; k < n_embd/2; ++k) { - for (int l = 0; l < n_codes; ++l) { - float mag = E[(k )*n_codes + l]; - float phi = E[(k + n_embd/2)*n_codes + l]; - mag = exp(mag); - if (mag > 1e2) { - mag = 1e2; - } - S[2*(k*n_codes + l) + 0] = mag*cosf(phi); - S[2*(k*n_codes + l) + 1] = mag*sinf(phi); - } - } - - for (int l = 0; l < n_codes; ++l) { - for (int k = 0; k < n_embd/2; ++k) { - ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0]; - ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1]; - } - } - - std::vector res (n_codes*n_fft); - std::vector hann2(n_codes*n_fft); - - std::vector workers(n_thread); - for (int i = 0; i < n_thread; ++i) { - workers[i] = std::thread([&, i]() { - for (int l = i; l < n_codes; l += n_thread) { - irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft); - for (int j = 0; j < n_fft; ++j) { - res [l*n_fft + j] *= hann[j]; - hann2[l*n_fft + j] = hann[j] * hann[j]; - } - } - }); - } - for (int i = 0; i < n_thread; ++i) { - workers[i].join(); - } - - std::vector audio; - std::vector env; - - fold(res, n_out, n_win, n_hop, n_pad, audio); - fold(hann2, n_out, n_win, n_hop, n_pad, env); - - for (size_t i = 0; i < audio.size(); ++i) { - audio[i] /= env[i]; - } - - return audio; -} - -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate) { +struct wav_header { + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t chunk_size; + char wave[4] = {'W', 'A', 'V', 'E'}; + char fmt[4] = {'f', 'm', 't', ' '}; + uint32_t fmt_chunk_size = 16; + uint16_t audio_format = 1; // PCM + uint16_t num_channels = 1; // Mono + uint32_t sample_rate; + uint32_t byte_rate; + uint16_t block_align; + uint16_t bits_per_sample = 16; + char data[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size; +}; + +static bool save_wav16(const std::string &fname, const std::vector &data, int sample_rate) { std::ofstream file(fname, std::ios::binary); if (!file) { LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); return false; } - struct wav_header { - char riff[4] = {'R', 'I', 'F', 'F'}; - uint32_t chunk_size; - char wave[4] = {'W', 'A', 'V', 'E'}; - char fmt[4] = {'f', 'm', 't', ' '}; - uint32_t fmt_chunk_size = 16; - uint16_t audio_format = 1; // PCM - uint16_t num_channels = 1; // Mono - uint32_t sample_rate; - uint32_t byte_rate; - uint16_t block_align; - uint16_t bits_per_sample = 16; - char data[4] = {'d', 'a', 't', 'a'}; - uint32_t data_size; - } header; - + wav_header header; header.sample_rate = sample_rate; header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); header.block_align = header.num_channels * (header.bits_per_sample / 8); @@ -169,95 +50,49 @@ static bool save_wav16(const std::string & fname, const std::vector & dat file.write(reinterpret_cast(&header), sizeof(header)); - for (const auto & sample : data) { - int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); + for (const auto &sample : data) { + int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0f, -32768.0f, 32767.0f)); file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); } return file.good(); } -// Supporting functions from tts.cpp (for embd_to_audio) -static void fill_hann_window(int length, bool periodic, float * output) { - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } +std::vector redistribute_codes(const std::vector& raw_codes) { + std::vector snac_codes; + for (size_t i = 0; i < raw_codes.size(); i += 7) { + if (i + 6 >= raw_codes.size()) break; + + // Subtract 128266 base and layer-specific offsets + snac_codes.push_back(raw_codes[i] - 128266); // Layer 1: offset 0 + snac_codes.push_back(raw_codes[i + 1] - 128266 - 4096); // Layer 2: offset 4096 + snac_codes.push_back(raw_codes[i + 2] - 128266 - 8192); // Layer 3: offset 8192 + snac_codes.push_back(raw_codes[i + 3] - 128266 - 12288); // Layer 3: offset 12288 + snac_codes.push_back(raw_codes[i + 4] - 128266 - 16384); // Layer 2: offset 16384 + snac_codes.push_back(raw_codes[i + 5] - 128266 - 20480); // Layer 3: offset 20480 + snac_codes.push_back(raw_codes[i + 6] - 128266 - 24576); // Layer 3: offset 24576 + } + return snac_codes; } -static void twiddle(float * real, float * imag, int k, int N) { - float angle = 2 * M_PI * k / N; - *real = cos(angle); - *imag = sin(angle); -} - -static void irfft(int n, const float * inp_cplx, float * out_real) { - int N = n / 2 + 1; - - std::vector real_input(N); - std::vector imag_input(N); - for (int i = 0; i < N; ++i) { - real_input[i] = inp_cplx[2 * i]; - imag_input[i] = inp_cplx[2 * i + 1]; - } - - std::vector real_output(n); - std::vector imag_output(n); - - for (int k = 0; k < n; ++k) { - real_output[k] = 0.0f; - imag_output[k] = 0.0f; - for (int m = 0; m < N; ++m) { - float twiddle_real; - float twiddle_imag; - - twiddle(&twiddle_real, &twiddle_imag, k * m, n); - - real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag; - imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real; - } - } - - for (int i = 0; i < n; ++i) { - out_real[i] = real_output[i] / N; - } +static void print_usage(int /*argc*/, char **argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]); + LOG("\n"); } -static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output) { - int64_t output_height = n_out; - int64_t kernel_w = n_win; - int64_t stride_w = n_hop; - int64_t width = n_out; - - output.resize(width, 0.0f); - - int64_t col_idx = 0; - for (int64_t w_col = 0; w_col < width; ++w_col) { - int64_t start = w_col * stride_w - n_pad; - int64_t end = start + kernel_w; - - for (int64_t w_im = start; w_im < end; ++w_im) { - if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) { - output[w_im] += data[col_idx]; - } - col_idx++; - } - } - - output.resize(n_out - 2 * n_pad); +static void prompt_add(std::vector &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) { + auto tmp = common_tokenize(vocab, txt, add_special, parse_special); + prompt.insert(prompt.end(), tmp.begin(), tmp.end()); } int main(int argc, char **argv) { common_params params; - + params.model = "models/orpheus-3b-0.1-ft-q4_k_m.gguf"; - params.vocoder.model = "models/snac-vocab.gguf"; + params.vocoder.model = "models/snac-fwd-pass-devel.gguf"; params.out_file = "output.wav"; - params.n_predict = 1200; params.sampling.top_k = 4; params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; params.n_batch = 4096; @@ -265,7 +100,8 @@ int main(int argc, char **argv) { common_init(); llama_backend_init(); llama_numa_init(params.numa); - + + common_init_result orpheus_init_ttc = common_init_from_params(params); llama_model * model_ttc = NULL; @@ -290,17 +126,15 @@ int main(int argc, char **argv) { prompt_add(tokens, vocab, "", false, true); // Emotion tag tokens.push_back(128009); // <|eot_id|> tokens.push_back(128260); // <|endofhuman|> - + llama_model * model_cts = NULL; llama_context * ctx_cts = NULL; params.model = params.vocoder.model; - params.n_batch = 2; params.embedding = true; - // disable warmup, SNAC doesn't care about BOS or EOS tokens; - params.warmup = false; + params.warmup = false; // SNAC doesn't care about BOS or EOS tokens common_init_result snac_init_cts = common_init_from_params(params); LOG_INF("SNAC model loaded: %s\n", params.model.c_str()); @@ -308,35 +142,80 @@ int main(int argc, char **argv) { model_cts = snac_init_cts.model.get(); ctx_cts = snac_init_cts.context.get(); - std::vector speech_codes = {100, 4200, 8500, 12500, 16500, 21000, 25000, - 200, 4300, 8600, 12600, 16600, 21111, 25100}; - - std::vector snac_codes = redistribute_codes(speech_codes); - - const int n_codes = speech_codes.size(); - const int batch_size = n_codes; - - llama_batch batch = llama_batch_init(batch_size, 0, 1); - - for (size_t i = 0; i < n_codes; ++i) { + // TODO: Use real orpheus codes + // Just some random numbers for testing + std::vector orpheus_codes = { + // Frame 1, 7 codes per frame + 128266 + 100, // L1: 100 + 128266 + 4096 + 200, // L2: 200 + 128266 + 8192 + 300, // L3: 300 + 128266 + 12288 + 400,// L3: 400 + 128266 + 16384 + 500,// L2: 500 + 128266 + 20480 + 600,// L3: 600 + 128266 + 24576 + 700,// L3: 700 + // Frame 2 + 128266 + 150, 128266 + 4096 + 250, 128266 + 8192 + 350, 128266 + 12288 + 450, + 128266 + 16384 + 550, 128266 + 20480 + 650, 128266 + 24576 + 750, + // Frame 3 + 128266 + 110, 128266 + 4096 + 210, 128266 + 8192 + 310, 128266 + 12288 + 410, + 128266 + 16384 + 510, 128266 + 20480 + 610, 128266 + 24576 + 710, + // Frame 4 + 128266 + 120, 128266 + 4096 + 220, 128266 + 8192 + 320, 128266 + 12288 + 420, + 128266 + 16384 + 520, 128266 + 20480 + 620, 128266 + 24576 + 720, + // Frame 5 + 128266 + 130, 128266 + 4096 + 230, 128266 + 8192 + 330, 128266 + 12288 + 430, + 128266 + 16384 + 530, 128266 + 20480 + 630, 128266 + 24576 + 730, + // Frame 6 + 128266 + 140, 128266 + 4096 + 240, 128266 + 8192 + 340, 128266 + 12288 + 440, + 128266 + 16384 + 540, 128266 + 20480 + 640, 128266 + 24576 + 740, + // Frame 7 + 128266 + 160, 128266 + 4096 + 260, 128266 + 8192 + 360, 128266 + 12288 + 460, + 128266 + 16384 + 560, 128266 + 20480 + 660, 128266 + 24576 + 760, + // Frame 8 + 128266 + 170, 128266 + 4096 + 270, 128266 + 8192 + 370, 128266 + 12288 + 470, + 128266 + 16384 + 570, 128266 + 20480 + 670, 128266 + 24576 + 770, + // Frame 9 + 128266 + 180, 128266 + 4096 + 280, 128266 + 8192 + 380, 128266 + 12288 + 480, + 128266 + 16384 + 580, 128266 + 20480 + 680, 128266 + 24576 + 780, + // Frame 10 + 128266 + 190, 128266 + 4096 + 290, 128266 + 8192 + 390, 128266 + 12288 + 490, + 128266 + 16384 + 590, 128266 + 20480 + 690, 128266 + 24576 + 790 + }; + + std::vector snac_codes = redistribute_codes(orpheus_codes); + + const int batch_size = snac_codes.size(); + + llama_batch batch = llama_batch_init(batch_size, 0, 1); + + for (size_t i = 0; i < batch_size; ++i) { common_batch_add(batch, snac_codes[i], i, {0}, true); } LOG_INF("Batch before decode: n_tokens = %d\n", batch.n_tokens); - if (llama_decode(ctx_cts, batch) != 0) { /* error */ } - - if (llama_decode(ctx_cts, batch) != 0) { /* error */ } - GGML_ASSERT(batch.n_tokens == n_codes); + GGML_ASSERT(batch.n_tokens == batch_size); batch.logits[batch.n_tokens - 1] = true; - + if (llama_decode(ctx_cts, batch) != 0) { LOG_ERR("Failed to decode SNAC batch\n"); return 1; } - llama_synchronize(ctx_cts); - LOG_INF("SNAC decode completed\n"); + llama_synchronize(ctx_cts); + + float* embd = llama_get_embeddings(ctx_cts); + if (!embd) { + LOG_ERR("No embeddings available\n"); + return 1; + } + + int n_samples = llama_get_n_outputs(ctx_cts); + std::vector audio(n_samples); + LOG_INF("n_samples: %i\n", n_samples); + memcpy(audio.data(), embd, n_samples * sizeof(float)); + + save_wav16(params.out_file, audio, 24000); llama_batch_free(batch); llama_backend_free(); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index def6eb3423c61..7bded06f88a94 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -14894,6 +14894,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: + case GGML_OP_SNAKE: { n_tasks = 1; } break; diff --git a/include/llama.h b/include/llama.h index 6a44be404d914..f98f1910bcf1c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -629,6 +629,8 @@ extern "C" { llama_seq_id * cells_sequences; }; + LLAMA_API int32_t llama_get_n_outputs(struct llama_context * ctx); + // Create an empty KV cache view. (use only for debugging purposes) LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5bec63e2e79ff..d15061655da39 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -851,6 +851,10 @@ float * llama_context::get_logits_ith(int32_t i) { } } +int32_t llama_context::get_n_outputs() { + return n_outputs; +} + float * llama_context::get_embeddings() { // reorder embeddings for backward compatibility output_reorder(); @@ -1403,10 +1407,21 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT(embd != nullptr); float * embd_out = embd + n_outputs_prev*n_embd; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + if (model.arch == LLM_ARCH_SNAC_DEC) { + // TODO: hack, SNAC outputs audio samples, not embeddings + // Rely on n_outputs for now, but perhaps add an `n_samples_snac` to + // llama_context to avoid doing these checks + int64_t n_samples = t_embd->ne[0]; + if (n_samples > 0) { + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_samples * sizeof(float)); + n_outputs = n_samples; // Update for downstream + } + } else { + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs * n_embd * sizeof(float)); + } } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -1471,8 +1486,11 @@ int llama_context::decode(llama_batch & inp_batch) { } } - // set to total number of outputs in the batch, for use in llama_get_logits_ith - n_outputs = n_outputs_all; + // TODO: Hack for now to avoid overwriting n_outputs in previous step + if (model.arch != LLM_ARCH_SNAC_DEC) { + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + } // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); @@ -2417,6 +2435,12 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } +int32_t llama_get_n_outputs(struct llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_n_outputs(); +} + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); diff --git a/src/llama-context.h b/src/llama-context.h index 04facb544cb1a..ff9ad663d1fe5 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -48,6 +48,8 @@ struct llama_context { float * get_logits(); float * get_logits_ith(int32_t i); + int32_t get_n_outputs(); + float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bee6e6bd359b4..4051c42852039 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1319,13 +1319,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_SNAC_DEC: { - hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; // From decoder_channel_dims + // TODO: Read from GGUF + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; hparams.upsample_rates = {8, 8, 4, 2}; hparams.n_embd = 768; hparams.n_layer = 8; - // Dummy KV cache params to satisfy llama.cpp - for (uint32_t i = 0; i < 7; ++i) { // n_total_layers = 8 + // Dummy KV cache params to satisfy init error + for (uint32_t i = 0; i < hparams.n_layer; ++i) { hparams.n_head_arr[i] = 1; hparams.n_head_kv_arr[i] = 1; } @@ -3716,8 +3717,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {8, 4096, 1}, 0); - hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; - // Quantizer projection tensors (0, 1, 2) for (int qid = 0; qid < 3; ++qid) { fprintf(stderr, "%s: Loading quantizer %d tensors\n", __func__, qid); @@ -3782,49 +3781,49 @@ bool llama_model::load_tensors(llama_model_loader & ml) { break; case 3: // Block 3: Residual Unit 1 { - int res_unit_idx = 0; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); } break; case 4: // Block 4: Residual Unit 2 { - int res_unit_idx = 1; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); } break; case 5: // Block 5: Residual Unit 3 { - int res_unit_idx = 2; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); } break; default: fprintf(stderr, "%s: ERROR: Unexpected block id %d in layer %d\n", __func__, bid, i); - return false; // Or handle error appropriately + return false; } fprintf(stderr, "%s: Layer %d, Block %d: Finished\n", __func__, i, bid); - } // End block loop + } } else if (i == 6) { // --- Layer 6: Alpha --- layer.alpha = create_tensor(tn(LLM_TENSOR_ALPHA, i, -1), {1, n_in, 1}, 0); @@ -3834,9 +3833,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S7, i, -1), {1, 1, n_out}, 0); layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B7, i, -1), {n_out}, 0); } - else { // Should not happen + else { fprintf(stderr, "%s: ERROR: Unexpected layer index %d\n", __func__, i); - return false; // Or handle error appropriately + return false; } fprintf(stderr, "%s: Layer %d: Finished\n", __func__, i); } @@ -11744,286 +11743,230 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; -// struct llm_build_snac_dec : public llm_graph_context { - -// llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { -// LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); -// for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { -// LLAMA_LOG_INFO("%d ", ubatch.token[i]); -// } -// LLAMA_LOG("\n"); -// LLAMA_LOG_DEBUG("%s: Entering constructor, model.layers.size() = %zu\n", __func__, model.layers.size()); -// ggml_tensor * cur; -// ggml_tensor * inpL; - -// // TODO: probalby just get raw codes -// //cur = build_inp_embd(model.tok_embd); -// //LLAMA_LOG_INFO("After build_inp_embd: shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // hack, hardcode expected SNAC input at first conv layer -// cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); // [channels, seq_len, 1, 1] -// ggml_set_input(cur); -// LLAMA_LOG_INFO("hardcoded shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // end hack - -// // Log input tokens before processing -// LLAMA_LOG_INFO("%s: ubatch.n_tokens = %u\n", __func__, ubatch.n_tokens); -// LLAMA_LOG_WARN("%s: Input tokens from ubatch = ", __func__); -// for (uint32_t i = 0; i < ubatch.n_tokens && i < 20; ++i) { -// LLAMA_LOG_INFO("%d ", ubatch.token[i]); -// } -// if (ubatch.n_tokens > 20) LLAMA_LOG_INFO("..."); -// LLAMA_LOG("\n"); - -// // ggml_tensor * layer_1; -// // ggml_tensor * layer_2; -// // ggml_tensor * layer_3; -// //redistribute_codes(cur, &layer_1, &layer_2, &layer_3); - -// // Log the redistributed layers -// //log_tensor("Layer 1", layer_1); -// //log_tensor("Layer 2", layer_2); -// //log_tensor("Layer 3", layer_3); - -// for (uint32_t il = 1; il < model.layers.size(); ++il) { -// const auto & layer = model.layers[il]; - -// LLAMA_LOG_DEBUG("%s: Layer %u: Starting, cur = %p\n", __func__, il, cur); - -// if (il == 1) { // pointwise -// LLAMA_LOG_INFO("%s: Layer %u: Pointwise conv, conv_w = %p, conv_s = %p, conv_b = %p\n", -// __func__, il, layer.conv_w, layer.conv_s, layer.conv_b); -// LLAMA_LOG_INFO("Before transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// cur = ggml_transpose(ctx0, cur); // [768, 512] -> [512, 768] -// LLAMA_LOG_INFO("After transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 0); -// LLAMA_LOG_INFO("%s: Layer %u: After pointwise conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } else if (il == model.layers.size() - 1) { -// LLAMA_LOG_INFO("%s: Layer %u: Final layer, alpha = %p, conv_w = %p, conv_s = %p, conv_b = %p\n", -// __func__, il, layer.alpha, layer.conv_w, layer.conv_s, layer.conv_b); -// cur = ggml_snake(ctx0, cur, layer.alpha); -// LLAMA_LOG_INFO("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 3); -// LLAMA_LOG_INFO("%s: Layer %u: After final conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = ggml_tanh(ctx0, cur); -// LLAMA_LOG_INFO("%s: Layer %u: After ggml_tanh, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } else { -// // Layers 2-5: Decoder Blocks (1024 -> 512 -> 256 -> 128 -> 64) -// const int stride = hparams.upsample_rates[il - 2]; // 8 for il = 2 -// const int padding = stride; - -// // Block 0: Snake activation -// const auto & block0 = layer.decoder_blocks[0]; -// LLAMA_LOG_DEBUG("%s: Layer %u: Block 0, alpha = %p\n", __func__, il, block0.alpha); -// cur = ggml_snake(ctx0, cur, block0.alpha); -// LLAMA_LOG_DEBUG("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // Block 1: Transposed convolution -// const auto & block1 = layer.decoder_blocks[1]; -// LLAMA_LOG_DEBUG("%s: Layer %u: Block 1, stride = %d, up_weight = %p, up_scale = %p, up_bias = %p\n", -// __func__, il, stride, block1.up_weight, block1.up_scale, block1.up_bias); - -// cur = apply_conv1d_transpose(cur, block1.up_weight, block1.up_scale, block1.up_bias, stride, padding); -// LLAMA_LOG_DEBUG("%s: Layer %u: After conv1d_transpose, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // Residual Units (3 per block) -// for (int j = 0; j < 3; ++j) { -// const auto & ru = block1.res_units[j]; -// ggml_tensor * inpL = cur; -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Starting, inpL = %p, alpha1 = %p, conv1_w = %p, conv1_s = %p, conv1_b = %p\n", -// __func__, il, j, inpL, ru.alpha1, ru.conv1_w, ru.conv1_s, ru.conv1_b); - -// cur = ggml_snake(ctx0, cur, ru.alpha1); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// int dilation = (j == 0) ? 1 : (j == 1) ? 3 : 9; -// int padding = 3 * dilation; // Kernel 7, dilated padding = (7-1)/2 * dilation -// cur = apply_conv1d(cur, ru.conv1_w, ru.conv1_s, ru.conv1_b, 1, padding); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); - -// // pw -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Pointwise, alpha2 = %p, conv2_w = %p, conv2_s = %p, conv2_b = %p\n", -// __func__, il, j, ru.alpha2, ru.conv2_w, ru.conv2_s, ru.conv2_b); -// cur = ggml_snake(ctx0, cur, ru.alpha2); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = apply_conv1d(cur, ru.conv2_w, ru.conv2_s, ru.conv2_b, 1, 0); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); - -// // residual -// cur = ggml_add(ctx0, cur, inpL); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_add, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } -// } -// LLAMA_LOG_DEBUG("%s: Layer %u: Finished, cur = %p\n", __func__, il, cur); -// } - -// int64_t target_samples = 24000; // TODO: magic number -// LLAMA_LOG_DEBUG("%s: Trimming output, cur = %p, target_samples = %ld, cur->ne[0] = %ld\n", -// __func__, cur, target_samples, cur ? cur->ne[0] : -1); -// if (cur->ne[0] > target_samples) { -// cur = ggml_get_rows(ctx0, cur, ggml_new_i32(ctx0, target_samples)); -// LLAMA_LOG_DEBUG("%s: After ggml_get_rows, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } - -// LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); -// cb(cur, "result_embd", -1); -// res->t_embd = cur; - -// LLAMA_LOG_DEBUG("%s: Building forward graph, cur = %p\n", __func__, cur); -// ggml_build_forward_expand(gf, cur); -// LLAMA_LOG_DEBUG("%s: Graph build completed\n", __func__); -// } - -// // TODO: move these somewhere else -// private: -// // Helper to log tensor contents -// void log_tensor(const char * name, ggml_tensor * tensor) { -// if (!tensor) { -// LLAMA_LOG_INFO("%s: %s is null\n", __func__, name); -// return; -// } -// LLAMA_LOG_DEBUG("%s: %s shape = [%ld, %ld, %ld, %ld], first 20 elements = ", -// __func__, name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); -// int n_elements = ggml_nelements(tensor); -// float * data = (float *)tensor->data; -// for (int i = 0; i < std::min(20, n_elements); ++i) { -// LLAMA_LOG_DEBUG("%.2f ", data[i]); -// } -// if (n_elements > 20) LLAMA_LOG_DEBUG("..."); -// LLAMA_LOG_DEBUG("\n"); -// } - -// void redistribute_codes(ggml_tensor * input, ggml_tensor ** layer_1, ggml_tensor ** layer_2, ggml_tensor ** layer_3) { -// int64_t n_codes = input->ne[1]; // Assuming input is [n_embd, n_tokens, 1, 1] -// int64_t n_frames = n_codes / 7; -// if (n_codes % 7 != 0) { -// LLAMA_LOG_ERROR("%s: Input codes length %ld is not a multiple of 7\n", __func__, n_codes); -// *layer_1 = *layer_2 = *layer_3 = nullptr; -// return; -// } - -// int64_t n_layer_1 = n_frames; // 1 code per frame -// int64_t n_layer_2 = n_frames * 2; // 2 codes per frame -// int64_t n_layer_3 = n_frames * 4; // 4 codes per frame - -// // Indices for each layer -// std::vector idx_layer_1(n_layer_1); -// std::vector idx_layer_2(n_layer_2); -// std::vector idx_layer_3(n_layer_3); - -// for (int64_t i = 0; i < n_frames; ++i) { -// int64_t base_idx = i * 7; -// idx_layer_1[i] = base_idx + 0; // No offset -// idx_layer_2[i * 2] = base_idx + 1; // Offset -4096 -// idx_layer_2[i * 2 + 1] = base_idx + 4; // Offset -16384 -// idx_layer_3[i * 4] = base_idx + 2; // Offset -8192 -// idx_layer_3[i * 4 + 1] = base_idx + 3; // Offset -12288 -// idx_layer_3[i * 4 + 2] = base_idx + 5; // Offset -20480 -// idx_layer_3[i * 4 + 3] = base_idx + 6; // Offset -24576 -// } - -// // Create index tensors -// ggml_tensor * idx_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); -// ggml_tensor * idx_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); -// ggml_tensor * idx_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); - -// memcpy(idx_1->data, idx_layer_1.data(), n_layer_1 * sizeof(int32_t)); -// memcpy(idx_2->data, idx_layer_2.data(), n_layer_2 * sizeof(int32_t)); -// memcpy(idx_3->data, idx_layer_3.data(), n_layer_3 * sizeof(int32_t)); - -// // Extract layers using ggml_get_rows -// *layer_1 = ggml_get_rows(ctx0, input, idx_1); -// *layer_2 = ggml_get_rows(ctx0, input, idx_2); -// *layer_3 = ggml_get_rows(ctx0, input, idx_3); - -// // Apply offsets -// *layer_2 = ggml_add(ctx0, *layer_2, ggml_new_f32(ctx0, -4096.0f)); // Simplified; we'll refine offsets later -// *layer_3 = ggml_add(ctx0, *layer_3, ggml_new_f32(ctx0, -8192.0f)); // Simplified for now -// } - -// ggml_tensor * apply_conv1d(ggml_tensor * input, ggml_tensor * conv_w, ggml_tensor * conv_scale, ggml_tensor * conv_b, -// int stride, int padding) { -// ggml_tensor * w_final = normalize_weight(conv_w, conv_scale); -// ggml_tensor * cur = ggml_conv_1d_ph(ctx0, w_final, input, stride, padding); -// if (conv_b) { -// ggml_tensor* bias_reshaped = ggml_reshape_3d(ctx0, conv_b, 1, 1024, 1); -// cur = ggml_add(ctx0, cur, bias_reshaped); -// } -// return cur; -// } - -// ggml_tensor * apply_conv1d_transpose(ggml_tensor * input, ggml_tensor * up_weight, ggml_tensor * up_scale, ggml_tensor * up_bias, int stride, int padding) { -// // Normalize weights (temporary fix for up_scale shape mismatch) -// if (up_scale->ne[2] != up_weight->ne[1]) { // 1024 != 512 -// LLAMA_LOG_WARN("up_scale channels (%ld) don’t match output channels (%ld), expected behavior may vary\n", up_scale->ne[2], up_weight->ne[1]); -// // Ideally reshape up_scale to [1, 1, 512, 1], but no reshape; proceed with warning -// } -// ggml_tensor * w_final = normalize_weight(up_weight, up_scale); -// LLAMA_LOG_INFO("After normalize weight: w_final shape = [%ld, %ld, %ld, %ld]\n", -// w_final->ne[0], w_final->ne[1], w_final->ne[2], w_final->ne[3]); - -// ggml_tensor * cur = ggml_conv_transpose_1d(ctx0, w_final, input, stride, 0, 1); -// LLAMA_LOG_INFO("After ggml_conv_transpose_1d = [%ld, %ld, %ld, %ld]\n", -// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// if (up_bias) { -// // up_bias is [512, 1, 1, 1]; need [4104, 512, 1, 1] for ggml_add -// LLAMA_LOG_INFO("entering up_bias block. Before ggml_repeat, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// LLAMA_LOG_INFO("Before ggml_repeat, up_bias shape = [%ld, %ld, %ld, %ld]\n", up_bias->ne[0], up_bias->ne[1], up_bias->ne[2], up_bias->ne[3]); -// ggml_tensor * bias_repeated = ggml_repeat(ctx0, up_bias, cur); -// LLAMA_LOG_DEBUG("Repeated up_bias to shape = [%ld, %ld, %ld, %ld]\n", -// bias_repeated->ne[0], bias_repeated->ne[1], bias_repeated->ne[2], bias_repeated->ne[3]); -// cur = ggml_add(ctx0, cur, bias_repeated); -// LLAMA_LOG_DEBUG("After bias add: cur shape = [%ld, %ld, %ld, %ld]\n", -// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// } -// return cur; -// } - -// // w_final = scale * (w / || w ||) -// ggml_tensor * normalize_weight(ggml_tensor * w, ggml_tensor * scale) { -// ggml_tensor * norm = ggml_norm(ctx0, w, 1e-5f); // 1e-8f ? -// ggml_tensor * w_normalized = ggml_div(ctx0, w, norm); -// ggml_tensor * w_final = ggml_mul(ctx0, w_normalized, scale); -// return w_final; -// } -// }; - // TODO: Placeholder struct llm_build_snac_dec : public llm_graph_context { llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * emb_layer_1, * emb_layer_2, * emb_layer_3; + build_codebook_embd(model, &emb_layer_1, &emb_layer_2, &emb_layer_3); + + if (emb_layer_1 == nullptr || emb_layer_2 == nullptr || emb_layer_3 == nullptr) { + // graph build is called with garbage ubatch codes during model init + // in this case, bypass normal graph construction and return a dummy + LLAMA_LOG_INFO("build_codebook_inputs returned null, using dummy tensor\n"); + cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, ubatch.n_tokens > 0 ? ubatch.n_tokens : 64, 1, 1); + ggml_set_input(cur); + } else { + // Projections + cur = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[0], 8, 768), emb_layer_1); + cur = ggml_reshape_4d(ctx0, cur, 768, emb_layer_1->ne[1], 1, 1); + ggml_tensor * scale_1 = ggml_reshape_4d(ctx0, model.codebook_proj_s[0], 768, 1, 1, 1); + cur = ggml_mul(ctx0, cur, scale_1); + ggml_tensor * bias_1 = ggml_reshape_4d(ctx0, model.codebook_proj_b[0], 768, 1, 1, 1); // Fix here + cur = ggml_add(ctx0, cur, bias_1); + + ggml_tensor * proj_2 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[1], 8, 768), emb_layer_2); + proj_2 = ggml_reshape_4d(ctx0, proj_2, 768, emb_layer_2->ne[1], 1, 1); + ggml_tensor * scale_2 = ggml_reshape_4d(ctx0, model.codebook_proj_s[1], 768, 1, 1, 1); + proj_2 = ggml_mul(ctx0, proj_2, scale_2); + ggml_tensor * bias_2 = ggml_reshape_4d(ctx0, model.codebook_proj_b[1], 768, 1, 1, 1); + proj_2 = ggml_add(ctx0, proj_2, bias_2); + + ggml_tensor * proj_3 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[2], 8, 768), emb_layer_3); + proj_3 = ggml_reshape_4d(ctx0, proj_3, 768, emb_layer_3->ne[1], 1, 1); + ggml_tensor * scale_3 = ggml_reshape_4d(ctx0, model.codebook_proj_s[2], 768, 1, 1, 1); + proj_3 = ggml_mul(ctx0, proj_3, scale_3); + ggml_tensor * bias_3 = ggml_reshape_4d(ctx0, model.codebook_proj_b[2], 768, 1, 1, 1); + proj_3 = ggml_add(ctx0, proj_3, bias_3); + + cur = ggml_concat(ctx0, cur, proj_2, 1); + cur = ggml_concat(ctx0, cur, proj_3, 1); + + for (int j = 1; j <= hparams.n_layer; ++j) { + const auto & layer = model.layers[j]; + const int64_t n_in = hparams.n_channels[j-1]; + const int64_t n_out = (j < 7) ? hparams.n_channels[j] : hparams.n_channels[j-1]; + + if (j == 1) { + int64_t seq_len = cur->ne[1]; + cur = ggml_reshape_2d(ctx0, cur, 768, seq_len); // cur starts F32 (type 0) from projections + ggml_tensor * w = ggml_reshape_2d(ctx0, layer.conv_w, 768, 1024); // F16 (type 1) + ggml_tensor * s = ggml_cpy(ctx0, layer.conv_s, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, 1, n_out)); // Cast F32 -> F16 + w = ggml_mul(ctx0, w, s); + cur = ggml_mul_mat(ctx0, w, cur); + cur = ggml_reshape_4d(ctx0, cur, seq_len, 1024, 1, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + // Residual Units + else if (j >= 2 && j <= 5) { + ggml_tensor * alpha = layer.decoder_blocks[0].alpha; + cur = ggml_snake(ctx0, cur, alpha); + + ggml_tensor * w = layer.decoder_blocks[1].up_weight; + ggml_tensor * s = ggml_cpy(ctx0, layer.decoder_blocks[1].up_scale, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_in, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_transpose_1d(ctx0, w, cur, hparams.upsample_rates[j-2], 0, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.decoder_blocks[1].up_bias, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + + ggml_tensor * noise_w = layer.decoder_blocks[2].noise_w; + ggml_tensor * noise_s = ggml_cpy(ctx0, layer.decoder_blocks[2].noise_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + noise_w = ggml_mul(ctx0, noise_w, noise_s); + cur = ggml_conv_1d(ctx0, noise_w, cur, 1, 0, 1); + + for (int r = 0; r < 3; ++r) { + int bid = 3 + r; + ggml_tensor * w1 = layer.decoder_blocks[bid].res_unit.conv1_w; + ggml_tensor * s1 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv1_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w1 = ggml_mul(ctx0, w1, s1); + cur = ggml_conv_1d_dw(ctx0, w1, cur, 1, 3, 1); + ggml_tensor * b1 = ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv1_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b1); + + ggml_tensor * w2 = layer.decoder_blocks[bid].res_unit.conv2_w; + ggml_tensor * s2 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv2_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w2 = ggml_mul(ctx0, w2, s2); + cur = ggml_conv_1d(ctx0, w2, cur, 1, 0, 1); + ggml_tensor * b2 = ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv2_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b2); + } + } + else if (j == 6) { + ggml_tensor * alpha = layer.alpha; + cur = ggml_snake(ctx0, cur, alpha); + } + else if (j == 7) { + ggml_tensor * w = layer.conv_w; + ggml_tensor * s = layer.conv_s; + + s = ggml_reshape_4d(ctx0, s, 1, 1, 1, 1); + s = ggml_cpy(ctx0, s, ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, 1, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_1d(ctx0, w, cur, 1, 3, 1); + + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + } - // TODO: Remove - LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); - for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { - LLAMA_LOG_INFO("%d ", ubatch.token[i]); } - LLAMA_LOG("\n"); - ggml_tensor * cur; - // TODO: Hack. Implement codebook lookups and out_proj - cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); - ggml_set_input(cur); - // end hack + cur = ggml_cpy(ctx0, cur, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3])); - LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); cb(cur, "result_embd", -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } +private: + // TODO: SNAC expects a multilayered input from 3 different embedding matrices + void build_codebook_embd(const llama_model & model, + ggml_tensor ** emb_layer_1, + ggml_tensor ** emb_layer_2, + ggml_tensor ** emb_layer_3) { + + *emb_layer_1 = nullptr; + *emb_layer_2 = nullptr; + *emb_layer_3 = nullptr; + + + + bool is_initialized = (ubatch.token != nullptr && ubatch.n_tokens > 0); + if (is_initialized) { + for (int i = 0; i < ubatch.n_tokens; ++i) { + if (ubatch.token[i] < 0 || ubatch.token[i] >= 4096) { + is_initialized = false; + break; + } + } + } + + if (!is_initialized) { + return; + } + + int32_t n_tokens = ubatch.n_tokens; + int32_t n_frames = n_tokens / 7; + if (n_tokens % 7 != 0) { + LLAMA_LOG_INFO("build_codebook_embd: n_tokens (%d) not a multiple of 7, truncating\n", n_tokens); + n_frames = n_tokens / 7; + } + + // TODO: read from vq_strides + int32_t n_layer_1 = n_frames; + int32_t n_layer_2 = n_frames * 2; + int32_t n_layer_3 = n_frames * 4; + + LLAMA_LOG_INFO("build_codebook_embd: n_frames = %d, n_layer_1 = %d, n_layer_2 = %d, n_layer_3 = %d\n", + n_frames, n_layer_1, n_layer_2, n_layer_3); + + std::vector idx_1_data(n_layer_1); + std::vector idx_2_data(n_layer_2); + std::vector idx_3_data(n_layer_3); + + // map codes to respective codebook + for (int32_t i = 0; i < n_frames; ++i) { + int32_t base_idx = i * 7; + idx_1_data[i] = ubatch.token[base_idx + 0]; + idx_2_data[i * 2] = ubatch.token[base_idx + 1]; + idx_2_data[i * 2 + 1] = ubatch.token[base_idx + 4]; + idx_3_data[i * 4] = ubatch.token[base_idx + 2]; + idx_3_data[i * 4 + 1] = ubatch.token[base_idx + 3]; + idx_3_data[i * 4 + 2] = ubatch.token[base_idx + 5]; + idx_3_data[i * 4 + 3] = ubatch.token[base_idx + 6]; + } + + // Tensors used for codebook lookups + ggml_tensor * idx_layer_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); + ggml_tensor * idx_layer_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); + ggml_tensor * idx_layer_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); + + if (!idx_layer_1 || !idx_layer_2 || !idx_layer_3) { + LLAMA_LOG_INFO("build_codebook_embd: Failed to allocate index tensors\n"); + return; + } + + // ggml is lazy, so explicitly create buffers for codes to be placed in idx_layer_N + ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); + if (!cpu_buft) { + LLAMA_LOG_ERROR("build_codebook_embd: Failed to get CPU buffer type\n"); + return; + } + + ggml_backend_buffer_t buffer_1 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_1 * sizeof(int32_t)); + ggml_backend_buffer_t buffer_2 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_2 * sizeof(int32_t)); + ggml_backend_buffer_t buffer_3 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_3 * sizeof(int32_t)); + + if (!buffer_1 || !buffer_2 || !buffer_3) { + LLAMA_LOG_ERROR("build_codebook_embd: Failed to allocate backend buffers\n"); + if (buffer_1) ggml_backend_buffer_free(buffer_1); + if (buffer_2) ggml_backend_buffer_free(buffer_2); + if (buffer_3) ggml_backend_buffer_free(buffer_3); + return; + } + + // move codes to idx_layer_N + idx_layer_1->buffer = buffer_1; + idx_layer_2->buffer = buffer_2; + idx_layer_3->buffer = buffer_3; + + idx_layer_1->data = ggml_backend_buffer_get_base(buffer_1); + idx_layer_2->data = ggml_backend_buffer_get_base(buffer_2); + idx_layer_3->data = ggml_backend_buffer_get_base(buffer_3); + + ggml_backend_tensor_set(idx_layer_1, idx_1_data.data(), 0, n_layer_1 * sizeof(int32_t)); + ggml_backend_tensor_set(idx_layer_2, idx_2_data.data(), 0, n_layer_2 * sizeof(int32_t)); + ggml_backend_tensor_set(idx_layer_3, idx_3_data.data(), 0, n_layer_3 * sizeof(int32_t)); + + *emb_layer_1 = ggml_get_rows(ctx0, model.codebook[0], idx_layer_1); + *emb_layer_2 = ggml_get_rows(ctx0, model.codebook[1], idx_layer_2); + *emb_layer_3 = ggml_get_rows(ctx0, model.codebook[2], idx_layer_3); + } }; llama_memory_i * llama_model::create_memory() const { diff --git a/src/llama-model.h b/src/llama-model.h index 5e636b0b3b3f3..e75bcf1ed8887 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -156,7 +156,7 @@ struct llama_layer_snac_dec_block { struct ggml_tensor * conv2_w = nullptr; struct ggml_tensor * conv2_s = nullptr; struct ggml_tensor * conv2_b = nullptr; - } res_units[3]; + } res_unit; }; struct llama_layer { @@ -328,7 +328,7 @@ struct llama_layer { struct llama_layer_convnext convnext; struct ggml_tensor * conv_w = nullptr; - struct ggml_tensor * conv_s = nullptr; + struct ggml_tensor * conv_s = nullptr; struct ggml_tensor * conv_b = nullptr; struct ggml_tensor * alpha = nullptr;