Source code for textpruner.pruners.pipeline_pruner

from .transformer_pruner import TransformerPruner
from .vocabulary_pruner import VocabularyPruner
from typing import Optional
from ..configurations import GeneralConfig,VocabularyPruningConfig,TransformerPruningConfig
import torch
from torch import nn
import os
import logging
logger = logging.getLogger(__name__)
from .utils import infer_model_type
from ..model_map import MODEL_MAP

class PipelinePruner:
    '''
    Args:
        model : The model to be pruned.
        tokenizer : The tokenizer for the model.
        vocabulary_pruning_config : a :class:`~textpruner.configurations.VocabularyPruningConfig` object.
        transformer_pruning_config : a :class:`~textpruner.configurations.TransformerPruningConfig` object.
        general_config : a :class:`~textpruner.configurations.GeneralConfig` object.
        base_model_prefix : The prefix of the base model, i.e., the name of the base model as a member of the model. \
For example, if ``model.bert_encoder = BertModel(...)``, then the ``base_model_prefix`` is ``bert_encoder``. \
TextPruner will try to infer the ``base_model_prefix``, so its value can usually be left as ``None``. If the inference fails, users have to set its value explicitly.
    '''
    def __init__(self,
                 model: nn.Module,
                 tokenizer,
                 transformer_pruning_config: Optional[TransformerPruningConfig] = None,
                 vocabulary_pruning_config: Optional[VocabularyPruningConfig] = None,
                 general_config: Optional[GeneralConfig] = None,
                 base_model_prefix: Optional[str] = None):
        self.model = model
        self.tokenizer = tokenizer
        self.general_config = GeneralConfig() if general_config is None else general_config
        self.transformer_pruning_config = TransformerPruningConfig() if transformer_pruning_config is None else transformer_pruning_config
        self.vocabulary_pruning_config = VocabularyPruningConfig() if vocabulary_pruning_config is None else vocabulary_pruning_config
        self.output_dir = self.general_config.output_dir

        # infer the base model and its type from the model (or from base_model_prefix if given)
        base_model, model_type = infer_model_type(model, base_model_prefix)
        assert model_type in MODEL_MAP, \
            f"Model type {model_type} is not supported, or not understood. Model type must be one of {list(MODEL_MAP.keys())}"
        self.base_model = base_model
        self.model_type = model_type

        # the pipeline delegates to a vocabulary pruner and a transformer pruner
        self.vocabulary_pruner = VocabularyPruner(model, tokenizer, vocabulary_pruning_config,
                                                  general_config, base_model_prefix=base_model_prefix)
        self.transformer_pruner = TransformerPruner(model, transformer_pruning_config,
                                                    general_config, base_model_prefix=base_model_prefix)

        self.save_dir = None

    def prune(self,
              dataloader=None,
              adaptor=None,
              batch_postprocessor=None,
              head_mask: Optional[torch.Tensor] = None,
              ffn_mask: Optional[torch.Tensor] = None,
              keep_shape=False,
              dataiter=None,
              additional_tokens=None,
              additional_token_ids=None,
              save_model=True) -> Optional[str]:
        '''
        Prunes the transformer layers, then prunes the vocabulary.

        Args:
            dataloader : a dataloader that generates batches. Each batch should contain both the inputs and the labels.
            adaptor : a function that takes the model output and returns the loss.
            batch_postprocessor : a function that takes a batch produced by the dataloader and returns a batch. It is used for post-processing the batches if needed.
            head_mask : a tensor of shape ``(num_layers, num_attention_heads)``. `1` means to keep, `0` means to prune.
            ffn_mask : a tensor of shape ``(num_layers, intermediate_hidden_size)``. `1` means to keep, `0` means to prune.
            keep_shape : if ``True``, the model is not actually pruned and the model structure is unchanged; instead, the weights that *should be pruned* are set to zero.
            dataiter : a list of pre-tokenized strings. These strings will be tokenized by the tokenizer to generate a set of tokens.
            additional_tokens : a list of tokens. These tokens must exist in the original vocabulary.
            additional_token_ids : a list of ints representing the token ids.
            save_model : whether to save the model when the pruning is finished.
        '''
        logger.info("Transformer pruning...")
        self.transformer_pruner.prune(dataloader, adaptor, batch_postprocessor=batch_postprocessor,
                                      keep_shape=keep_shape, head_mask=head_mask, ffn_mask=ffn_mask,
                                      save_model=False)
        logger.info("Vocabulary pruning...")
        self.vocabulary_pruner.prune(dataiter=dataiter, additional_tokens=additional_tokens,
                                     additional_token_ids=additional_token_ids, save_model=False)
        if save_model is True:
            self.save_dir = self.save_model()
        return self.save_dir

    def save_model(self, dir_name=None) -> str:
        # per-layer FFN sizes after pruning, derived from the FFN mask
        ffn_sizes = self.transformer_pruner.ffn_mask.to(int).sum(-1).tolist()
        if self.transformer_pruner.keep_shape is False:
            ffn_size = ffn_sizes[0]
            num_of_heads = self.transformer_pruner.head_mask.sum().item() / self.transformer_pruner.head_mask.size(0)
            if len(set(ffn_sizes)) != 1:
                raise NotImplementedError("Cannot save a pruned model with different ffn sizes per layer when keep_shape=False. \
Call PipelinePruner.save_masks or PipelinePruner.save_jit_model manually instead.")
            else:
                self.base_model.config.intermediate_size = ffn_size
        else:
            ffn_size = self.transformer_pruner.ffn_mask.size(1)       # base_model.config.intermediate_size
            num_of_heads = self.transformer_pruner.head_mask.size(1)  # self.transformer_pruning_config.target_num_of_heads

        vocab_size = len(self.vocabulary_pruner.pruned_token_ids)
        self.base_model.config.vocab_size = vocab_size

        if dir_name is None:
            save_dir = os.path.join(self.general_config.output_dir, f'pruned_V{vocab_size}H{num_of_heads}F{ffn_size}')
        else:
            save_dir = os.path.join(self.general_config.output_dir, dir_name)
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(save_dir, 'pytorch_model.bin'))

        # save config
        self.base_model.config.save_pretrained(save_dir)
        # save tokenizer
        self.vocabulary_pruner.tokenizer_helper.save_vocab(self.tokenizer, self.vocabulary_pruner.pruned_token_ids, save_dir)

        logger.info(f"Model and configuration have been saved to {save_dir}")
        return save_dir

    def save_jit_model(self, example_inputs, dir_name=None) -> str:
        self.model.eval()
        with torch.no_grad():
            traced_model = torch.jit.trace(self.model, example_inputs=example_inputs, strict=False)
        if dir_name is None:
            # derive the head and FFN statistics from the masks for the default directory name
            ffn_size = int(self.transformer_pruner.ffn_mask.to(int).sum(-1)[0].item())
            num_of_heads = self.transformer_pruner.head_mask.sum().item() / self.transformer_pruner.head_mask.size(0)
            save_dir = os.path.join(self.general_config.output_dir, f'pruned_H{num_of_heads}F{ffn_size}_traced')
        else:
            save_dir = os.path.join(self.general_config.output_dir, dir_name)
        os.makedirs(save_dir, exist_ok=True)
        torch.jit.save(traced_model, os.path.join(save_dir, 'pytorch_model.ts'))
        return save_dir
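
For reference, a minimal usage sketch of ``PipelinePruner``. The model name, dataset, dataloader, and adaptor are placeholders: the sketch assumes a ``transformers`` sequence-classification model whose output exposes a ``.loss`` attribute and a dataloader yielding dict batches with inputs and labels. The configuration fields ``target_ffn_size``, ``target_num_of_heads``, ``pruning_method``, and ``n_iters`` belong to ``TransformerPruningConfig``, which is defined outside this module, so treat them as assumptions rather than as part of this source.

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from textpruner.configurations import GeneralConfig, TransformerPruningConfig, VocabularyPruningConfig
from textpruner.pruners.pipeline_pruner import PipelinePruner

# placeholder model and tokenizer; any model type listed in MODEL_MAP should work
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# a tiny toy dataset; each batch is a dict containing both the inputs and the labels
texts = ["an example sentence", "another example sentence"]
labels = torch.tensor([0, 1])
encoded = tokenizer(texts, padding=True, return_tensors='pt')
dataset = [{'input_ids': encoded['input_ids'][i],
            'attention_mask': encoded['attention_mask'][i],
            'labels': labels[i]} for i in range(len(texts))]
dataloader = DataLoader(dataset, batch_size=2)

# the adaptor takes the model output and returns the loss
def adaptor(model_outputs):
    return model_outputs.loss

# assumed configuration values; see textpruner.configurations for the actual defaults
transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=2048, target_num_of_heads=8,
    pruning_method='iterative', n_iters=1)

pruner = PipelinePruner(model, tokenizer,
                        transformer_pruning_config=transformer_pruning_config,
                        vocabulary_pruning_config=VocabularyPruningConfig(),
                        general_config=GeneralConfig())

# transformer pruning uses the dataloader and adaptor;
# vocabulary pruning keeps only the tokens that appear in `texts`
save_dir = pruner.prune(dataloader=dataloader, adaptor=adaptor,
                        dataiter=texts, save_model=True)
print(save_dir)

With ``save_model=True``, ``prune`` returns the directory written by ``save_model`` above, whose default name encodes the final vocabulary size, average number of heads per layer, and FFN size (``pruned_V..H..F..``).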