import json
from argparse import ArgumentParser

import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, BatchEncoding

"""
Usage examples (with the best batch sizes on A100-80GB-400W)
============================================================
python -m benchmark_hf_model --model_name_or_path="Deci/DeciLM-7B" --batch_size=352
python -m benchmark_hf_model --model_name_or_path="mistralai/Mistral-7B-v0.1" --batch_size=192 --model_kwargs_json='{"use_flash_attention_2": true}'
python -m benchmark_hf_model --model_name_or_path="meta-llama/Llama-2-7b-hf" --batch_size=48 --model_kwargs_json='{"use_flash_attention_2": true}'
"""

def parse_args():
    parser = ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--warmup_iters",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=5,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        help="Model precision, one of: fp32, fp16, bf16",
    )
    parser.add_argument(
        "--model_kwargs_json",
        type=str,
        default=None,
        help="JSON dict of extra kwargs passed to AutoModelForCausalLM.from_pretrained",
    )
    return parser.parse_args()

def main():
    args = parse_args()
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()

    dict_precisions = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Invalid precision {args.precision}, choose from: fp16, fp32, bf16"
        )
    dtype = dict_precisions[args.precision]

    model_kwargs = {}
    if args.model_kwargs_json is not None:
        model_kwargs = json.loads(args.model_kwargs_json)

    print("loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=dtype,
        **model_kwargs,
    )
    try:
        print(model.model.layers[0].self_attn)
    except Exception:
        print("couldn't print the model's attention module")

    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    model.cuda()
    model.eval()

    # Dummy prompt: a batch of identical sequences filled with token id 1.
    prompt = torch.ones(args.prompt_length, dtype=torch.long)
    inputs = BatchEncoding({"input_ids": prompt.repeat(args.batch_size, 1)})
    inputs = inputs.to(model.device)

    print(f"warming up for {args.warmup_iters} iterations...")
    for _ in range(args.warmup_iters):
        with torch.no_grad():
            _ = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=-1234,  # an id no real token has, so generation never stops early
            )
    print("finished warmup")
    torch.cuda.synchronize()

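    # Timed section: each iteration below runs a full prefill over the dummy prompt
    # plus greedy generation of max_new_tokens, timed with the CUDA events created above.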
    print(
        f"prefill ({args.prompt_length} tokens{f' x {args.batch_size} batch' if args.batch_size > 1 else ''})"
        f" + generation ({args.max_new_tokens} tokens{f' x {args.batch_size} batch' if args.batch_size > 1 else ''}):"
    )
    tokens_generated = args.max_new_tokens * args.batch_size
    prefill_and_generation = []
    for gen_iter in range(args.iterations):
        starter.record()
        with torch.no_grad():
            _ = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                eos_token_id=-1234,
            )
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000  # elapsed_time returns milliseconds
        prefill_and_generation.append(t)
        print(f" iter {gen_iter + 1}: {t:.03f} sec total, {tokens_generated / t:.02f} generated tokens/sec")
    aver = sum(prefill_and_generation) / len(prefill_and_generation)
    print(f" average: {aver:.03f} sec total, {tokens_generated / aver:.02f} generated tokens/sec")
    print(f"These results are obtained for model '{args.model_name_or_path}' with {args.batch_size=}.")

if __name__ == "__main__":
    main()