Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import contexttimer
from colorama import Fore, Style
from transformers import AutoTokenizer, AutoModelForCausalLM

import time
from sampling import autoregressive_sampling, speculative_sampling, speculative_sampling_v2
from globals import Decoder

Expand Down Expand Up @@ -95,32 +95,40 @@ def generate(input_text, approx_model_name, target_model_name, num_tokens=20, ga
top_p = 0.9

torch.manual_seed(123)
st=time.time()
output = autoregressive_sampling(input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
color_print(f"large (target) model autoregressive_sampling: {generated_text}")
tot=time.time()-st
color_print(f"large (target) model autoregressive_sampling {tot:.4f}s: {generated_text}")

if use_benchmark:
benchmark(autoregressive_sampling, "AS_large", use_profiling,
input_ids, large_model, num_tokens, top_k = top_k, top_p=top_p)

torch.manual_seed(123)
st=time.time()
output = autoregressive_sampling(input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
color_print(f"small (approx) model autoregressive_sampling: {generated_text}")
tot=time.time()-st
color_print(f"small (approx) model autoregressive_sampling {tot:.4f}s: {generated_text}")

if use_benchmark:
benchmark(autoregressive_sampling, "AS_small", use_profiling,
input_ids, small_model, num_tokens, top_k = top_k, top_p=top_p)

torch.manual_seed(123)
st=time.time()
output = speculative_sampling_v2(input_ids, small_model, large_model, num_tokens, top_k = top_k, top_p=top_p, random_seed = random_seed)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
color_print(f"deepmind's speculative_sampling: {generated_text}")
tot=time.time()-st
color_print(f"deepmind's speculative_sampling {tot:.4f}s: {generated_text}")

torch.manual_seed(123)
st=time.time()
output = speculative_sampling(input_ids, small_model, large_model, num_tokens, gamma = gamma, top_k = top_k, top_p=top_p, random_seed = random_seed, verbose = verbose)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
color_print(f"google's speculative_sampling: {generated_text}")
tot=time.time()-st
color_print(f"google's speculative_sampling {tot:.4f}s: {generated_text}")

if use_benchmark:
benchmark(speculative_sampling, "SP", use_profiling,
Expand Down