From 8da8305c1815e65052eae0be189a9b6d628993ee Mon Sep 17 00:00:00 2001
From: haukurb
Date: Tue, 7 Nov 2023 17:28:22 +0000
Subject: [PATCH] Added uncertainty estimation scripts for ByT5 architecture
 models

---
 src/greynirseq/utils/infer_file_byt5.py | 101 +++++++
 .../utils/score_byt5_uncertainty.py     | 252 ++++++++++++++++++
 2 files changed, 353 insertions(+)
 create mode 100644 src/greynirseq/utils/infer_file_byt5.py
 create mode 100644 src/greynirseq/utils/score_byt5_uncertainty.py

diff --git a/src/greynirseq/utils/infer_file_byt5.py b/src/greynirseq/utils/infer_file_byt5.py
new file mode 100644
index 00000000..378dea8b
--- /dev/null
+++ b/src/greynirseq/utils/infer_file_byt5.py
@@ -0,0 +1,101 @@
+"""
+Script that runs file-to-file inference (generation) with a ByT5 model loaded from a checkpoint,
+optionally with inference-time dropout enabled.
+"""
+import copy
+import itertools
+import time
+from pathlib import Path
+
+import click
+import torch
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+MAX_LENGTH = 512
+
+
+def batch_by_n(iterable, batch_size):
+    # collect into batches of n items, yielding after each batch
+    _iter = iter(iterable)
+    while True:
+        # the list call is necessary since islice is a lazily evaluated generator
+        batch = list(itertools.islice(_iter, batch_size))
+        if not batch:
+            break
+        yield batch
+
+
+# fmt: off
+@click.command()
+@click.option("--checkpoint-path", type=click.Path(exists=True, path_type=Path), required=True)
+@click.option("--input-path", type=click.Path(exists=True, path_type=Path), required=True)
+@click.option("--output-path", type=click.Path(exists=False, path_type=Path), required=True)
+@click.option("--use-cpu", is_flag=True, default=False)
+@click.option("--num-beams", type=int, default=5)
+@click.option("--dropout", type=float, default=0.0)
+@click.option("--batch-size", type=int, default=64)
+def main(checkpoint_path, input_path, output_path, use_cpu, num_beams, dropout, batch_size):
+    # fmt: on
+    use_gpu = not use_cpu
+    if use_gpu:
+        assert torch.cuda.is_available()
+
+    tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
+    # NOTE: inference-time dropout (enabled below when --dropout > 0) requires the model
+    # not to be in eval mode, i.e. the computation graph may be cached for gradients.
+
+    model = T5ForConditionalGeneration.from_pretrained(str(checkpoint_path))
+    if dropout > 0.0:
+        # Inference-time dropout.
+        # To make sure the correct dropout value is propagated properly,
+        # we do a "round trip": reconstruct the model with a modified config.
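+        # (The nn.Dropout submodules read config.dropout_rate when they are constructed,
+        # so just mutating model.config on the already-loaded model would not be enough.)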
+        config = copy.deepcopy(model.config)
+        config.dropout_rate = dropout
+        model = T5ForConditionalGeneration.from_pretrained(str(checkpoint_path), config=config)
+        model = model.train()
+    else:
+        model = model.eval()
+    model = model.to("cuda").half() if use_gpu else model
+
+    def prepare_model_inputs(data_lines):
+        model_inputs = tokenizer(
+            data_lines,
+            truncation=True,
+            padding=True,
+            max_length=MAX_LENGTH,
+            return_tensors="pt",
+        )
+        if use_gpu:
+            model_inputs = model_inputs.to("cuda")
+        return model_inputs
+
+    with open(input_path) as f:
+        # strip is important, otherwise the trailing newline becomes part of the model input
+        test_lines = [line.strip() for line in f]
+
+    # parameters for model.generate
+    default_generate_kwargs = {
+        "max_length": MAX_LENGTH,
+        "num_beams": num_beams,
+        "output_scores": True,
+        "return_dict_in_generate": False,
+    }
+
+    start_time = time.time()
+    # make sure the output directory exists
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w") as out_f:
+        for batch_lines in batch_by_n(test_lines, batch_size=batch_size):
+            model_inputs = prepare_model_inputs(batch_lines)
+
+            model_output = model.generate(**model_inputs, **default_generate_kwargs)
+
+            decoded_output = tokenizer.batch_decode(
+                model_output, skip_special_tokens=True
+            )
+            for line in decoded_output:
+                out_f.write(line + "\n")
+
+    end_time = time.time()
+    print(f"Finished {output_path} in {end_time - start_time:.2f} seconds")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/greynirseq/utils/score_byt5_uncertainty.py b/src/greynirseq/utils/score_byt5_uncertainty.py
new file mode 100644
index 00000000..013d8417
--- /dev/null
+++ b/src/greynirseq/utils/score_byt5_uncertainty.py
@@ -0,0 +1,252 @@
+"""
+Script that performs dropout Monte-Carlo uncertainty estimation with a ByT5 model loaded from a checkpoint file.
+
+The output is a JSONL file with the following (line-wise) format:
+{
+    "mean": [float, float, ...],
+    "std": [float, float, ...],
+    "source": "string",
+    "output": "string",
+}
+"""
+import itertools
+import json
+from pathlib import Path
+
+import click
+import torch
+import tqdm
+from transformers import AutoTokenizer, T5ForConditionalGeneration
+
+MAX_LENGTH = 512
+
+
+def batch_by_n(iterable, batch_size):
+    # collect into batches of n items, yielding after each batch
+    _iter = iter(iterable)
+    while True:
+        # the list call is necessary since islice is a lazily evaluated generator
+        batch = list(itertools.islice(_iter, batch_size))
+        if not batch:
+            break
+        yield batch
+
+
+def set_t5_dropout(model: torch.nn.Module, dropout: float):
+    """
+    Adjust the internal state of the dropout modules in a huggingface T5 model.
+
+    Storing the computation tree for backpropagation (which full "training mode" implies) more
+    than quadruples the CUDA memory needed for a given batch!
+    The torch dropout modules have no "do-inference-at-test-time" hyperparameter, but we can
+    produce that behavior by changing their internal state ourselves, i.e. by only partially
+    activating the "training mode" of the dropout modules.
+    This does not mean that the memory usage is the same as during training, since generation:
+    - is guaranteed to discard all of the computation tree
+    - discards the decoder hidden states (it only caches the k and v values of the attention instead)
+    - prunes the short sequences during decoding when the batch is unbalanced in terms of padding,
+      so they don't cause the same memory overhead as the longest sequences do during training.
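+
+    A minimal usage sketch (illustrative; this mirrors how main() below drives it):
+
+        set_t5_dropout(model, dropout=0.1)  # enable inference-time (Monte-Carlo) dropout for scoring
+        set_t5_dropout(model, dropout=0.0)  # back to deterministic, eval-like behavior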
+    """
+    assert 0 <= dropout <= 1
+    # flattened recursive traversal
+    for mod in model.modules():
+        if not isinstance(mod, torch.nn.Dropout):
+            continue
+        mod.p = dropout
+        if dropout > 0:
+            # Save some memory
+            mod.inplace = True
+            # Calling mod.train() causes the computation tree to be stored, which we don't want.
+            # This may not be guaranteed by the module "interface" (in Java terminology),
+            # but at least in the current version of torch (tested on CUDA) setting the flag
+            # directly causes dropout to be applied without storing most (if any) of the
+            # computation tree needed for backpropagation.
+            mod.training = True
+            assert mod.training
+        else:
+            # mod.inplace = False
+            mod.eval()
+            assert not mod.training
+
+
+# fmt: off
+@click.command()
+@click.option("--checkpoint-path", type=click.Path(exists=True, path_type=Path), required=True)
+@click.option("--input-path", type=click.Path(exists=True, path_type=Path), required=True)
+@click.option("--output-path", type=click.Path(exists=False, path_type=Path), required=True)
+@click.option("--use-cpu", is_flag=True, default=False)
+@click.option("--dropout", type=float, default=0.1)
+@click.option("--seed", type=int, default=1)
+@click.option("--num-iter", type=int, default=5)
+@click.option("--batch-size", type=int, default=64)
+def main(
+    checkpoint_path,
+    input_path,
+    output_path,
+    use_cpu,
+    dropout,
+    seed,
+    num_iter,
+    batch_size,
+):
+    # fmt: on
+    # (assuming max_seqlen=512 tokens on a 40GB A100)
+    # a batch size of 768 works for generation, but 14-16 is the maximum for scoring;
+    # however, using torch.no_grad allows scoring with a batch size of
+    # at least 500 (but less than 768).
+    # Using such a high batch size is extremely suboptimal unless
+    # the sequence lengths in a batch are homogeneous. We can make
+    # the batches more uniform by sorting the input by length prior to batching.
+    # [1000 samples total] @  64 bsz took 69.3 sec with sorting
+    #                        32 bsz took 83.6 sec with sorting
+    #                       128 bsz took 67.4 sec with sorting
+    #                       256 bsz took 79.4 sec with sorting
+
+    print(f"Writing to {output_path} with dropout={dropout} and num_iter={num_iter}")
+    use_gpu = not use_cpu
+    if use_gpu:
+        assert torch.cuda.is_available()
+
+    tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
+
+    print("Loading model...")
+    # We are using the byte-level version, ByT5 (which is implemented using the same class as T5)
+    model = T5ForConditionalGeneration.from_pretrained(str(checkpoint_path))
+    model = model.to("cuda").half() if use_gpu else model
+
+    def prepare_model_inputs(data_lines):
+        model_inputs = tokenizer(
+            data_lines,
+            truncation=True,
+            padding=True,
+            max_length=MAX_LENGTH,
+            return_tensors="pt",
+        )
+        if use_gpu:
+            model_inputs = model_inputs.to("cuda")
+        return model_inputs
+
+    with open(input_path) as f:
+        # strip is important, otherwise the trailing newline becomes part of the model input
+        src_lines = [line.strip() for line in f]
+
+    # NOTE: this sorts the whole file; it would be premature optimization to do this
+    # in a cleverer manner, such as incremental partial sorting where we separately sort
+    # the first 10k lines, then lines 10k-20k, etc.
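+    # (The length-sorted order is undone again near the end of main(); see the
+    # "recover the original order" block below, which indexes with sorted_indices.argsort().)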
+    seqlens = torch.tensor([len(line.encode("utf-8")) for line in src_lines]).long()
+    sorted_indices = seqlens.argsort(descending=True)
+    src_lines = [src_lines[i] for i in sorted_indices.tolist()]
+
+    # parameters for model.generate
+    default_generate_kwargs = {
+        "max_length": MAX_LENGTH,
+        "num_beams": 1,
+        "output_scores": True,
+        "return_dict_in_generate": False,
+    }
+
+    print("Generating...")
+    model = model.eval()
+    permuted_means = []
+    permuted_stds = []
+    decoded_outputs = []
+    with torch.no_grad():
+        for batch_idx, batch_lines in enumerate(
+            batch_by_n(tqdm.tqdm(src_lines), batch_size=batch_size)
+        ):
+            # this call is very cheap
+            set_t5_dropout(model, dropout=0.0)
+            model_inputs = prepare_model_inputs(batch_lines)
+
+            # generate the hypotheses in eval mode (no dropout)
+            model_outputs = model.generate(**model_inputs, **default_generate_kwargs)
+            decoded_output = tokenizer.batch_decode(
+                model_outputs, skip_special_tokens=True
+            )
+            decoded_outputs.extend(decoded_output)
+
+            set_t5_dropout(model, dropout=dropout)
+
+            # score the hypotheses with inference-time dropout
+            iter_scores = []
+            src_ids = model_inputs["input_ids"]
+            # we need to pass the attention mask to the model (encoder attention mask)
+            encoder_attention_mask = model_inputs["attention_mask"]
+            tgt_ids = prepare_model_inputs(decoded_output)["input_ids"]
+            cpu = torch.device("cpu")
+            tgt_ids_cpu = tgt_ids.to(cpu)
+            # after we get the scores we move the rest of the work to the cpu
+            decoder_target_mask = tgt_ids_cpu.eq(tokenizer.pad_token_id).logical_not()
+            # we need to know how many tokens are in each sequence (to filter out padding tokens)
+            lens = decoder_target_mask.sum(dim=1)
+            for iteration_idx in range(num_iter):
+                # for reproducibility (hashing a tuple of ints is deterministic across runs,
+                # unlike hashing of strings)
+                batch_iter_seed = hash((seed, batch_idx, iteration_idx))
+                _rng_gen = torch.manual_seed(batch_iter_seed)
+
+                # get one monte-carlo iteration of scores with dropout
+                output = model(
+                    input_ids=src_ids,
+                    labels=tgt_ids,
+                    attention_mask=encoder_attention_mask,
+                    use_cache=True,
+                )
+                unnormalized_score = output.logits.detach()
+                del output
+                # normalize the scores to get a (log) probability distribution
+                # and move to cpu (so the reference counting can free it sooner)
+                score = (
+                    unnormalized_score.log_softmax(-1).detach().clone().float().to(cpu)
+                )
+
+                # Select the scores of the tokens that are actually in the sequence
+                # (we get scores for the whole vocabulary inventory at every position).
+                # We have:
+                # - an array of indices I_ij (here: tgt_ids_cpu)
+                # - an array of floats S_ijk (here: score)
+                # We want the output matrix O_ij after using I as an index into S as follows:
+                #     O_ij = S_{i}{j}{I_ij}
+                scores_without_padding = score.gather(
+                    dim=2, index=tgt_ids_cpu.unsqueeze(-1)
+                ).squeeze(-1)
+                # filter out padding tokens
+                scores_without_padding_tuples = scores_without_padding[
+                    decoder_target_mask
+                ].split(lens.tolist())
+                iter_scores.append(scores_without_padding_tuples)
+
+            # collect, reduce and store the scores
+            # (read the size off the tensor since the trailing batch may be smaller than batch_size)
+            actual_batch_size, _ = tgt_ids_cpu.shape
+            for seq_index in range(actual_batch_size):
+                seq_scores = [
+                    iter_scores[iter_index][seq_index] for iter_index in range(num_iter)
+                ]
+                seq_scores = torch.stack(seq_scores, dim=0)
+                seq_means = seq_scores.mean(dim=0)
+                seq_stds = seq_scores.std(dim=0)
+                permuted_means.append(seq_means)
+                permuted_stds.append(seq_stds)
+
+    # recover the original order by inverting the length-sorted indices
+    inverse_indices = sorted_indices.argsort().tolist()
+    permuted_means = [permuted_means[i] for i in inverse_indices]
+    permuted_stds = [permuted_stds[i] for i in inverse_indices]
+    decoded_outputs = [decoded_outputs[i] for i in inverse_indices]
+    # src_lines were sorted by length above, so restore their original order as well
+    src_lines = [src_lines[i] for i in inverse_indices]
+
+    # write results to disk
+    with open(output_path, "w") as f:
+        for mean, std, src_line, decoded_output in zip(
+            permuted_means, permuted_stds, src_lines, decoded_outputs
+        ):
+            obj = {
+                "mean": mean.tolist(),
+                "std": std.tolist(),
+                "source": src_line,
+                "output": decoded_output,
+            }
+            f.write(json.dumps(obj, ensure_ascii=False))
+            f.write("\n")
+
+
+if __name__ == "__main__":
+    main()