Welcome to TextPruner’s documentation

TextPruner is a toolkit for pruning pre-trained transformer-based language models written in PyTorch. It offers structured training-free pruning methods and a user-friendly interface.

The main features of TextPruner include:

  • Compatibility: TextPruner is compatible with various pre-trained NLU models. You can use it to prune your own models for various NLP tasks, as long as they are built on the standard pre-trained models.

  • Usability: TextPruner can be used as a package or as a CLI tool. Both are easy to use.

  • Efficiency: TextPruner reduces the model size quickly and simply. It uses structured, training-free methods to prune models, which is much faster than distillation and other pruning methods that involve training.

TextPruner currently supports the following pre-trained models in transformers:

  • BERT

  • ALBERT

  • ELECTRA

  • RoBERTa

  • XLM-RoBERTa

Installation

pip install textpruner

Note

This document is under development.

Configurations

Config

class textpruner.configurations.Config[source]

Base class for GeneralConfig, VocabularyPruningConfig and TransformerPruningConfig.

classmethod from_dict(config_map: dict)[source]

Construct the configuration from a dict.

classmethod from_json(json_filename: str)[source]

Construct the configuration from a JSON file.

save_to_json(json_filename: str)[source]

Save the configuration to a JSON file.
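
For example, a configuration can be round-tripped through a dict or a JSON file. A minimal sketch using the GeneralConfig subclass, assuming the dict keys mirror the constructor arguments:

Example:

from textpruner.configurations import GeneralConfig

# Build a configuration from a dict, save it to a JSON file, and load it back.
config = GeneralConfig.from_dict({'use_device': 'auto', 'output_dir': './pruned_models'})
config.save_to_json('general_config.json')
config = GeneralConfig.from_json('general_config.json')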

GeneralConfig

class textpruner.configurations.GeneralConfig(use_device: str = 'auto', output_dir: str = './pruned_models', config_class: str = 'GeneralConfig')[source]

Configurations for the device and the output directory.

Parameters
  • use_device – 'cpu' or 'cuda' or 'cuda:0' etc. Specifies which device to use. If set to 'auto', TextPruner will use the CUDA device if one is available; otherwise it uses the CPU.

  • output_dir – The directory to save the pruned models.

  • config_class – Type of the configurations. Users should not change its value.
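
A minimal usage sketch (the device and directory values are illustrative):

Example:

from textpruner.configurations import GeneralConfig

# Use the first CUDA device and write pruned models to a custom directory.
general_config = GeneralConfig(use_device='cuda:0', output_dir='./pruned_models')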

VocabularyPruningConfig

class textpruner.configurations.VocabularyPruningConfig(min_count: int = 1, prune_lm_head: Union[bool, str] = 'auto', config_class: str = 'VocabularyPruningConfig')[source]

Configurations for vocabulary pruning.

Parameters
  • min_count – The threshold that decides whether a token should be removed. A token will be removed from the vocabulary if it appears fewer than min_count times in the corpus.

  • prune_lm_head – whether to prune the lm_head if the model has one. If prune_lm_head==False, TextPruner will not prune the lm_head; if prune_lm_head==True, TextPruner will prune the lm_head and raise an error if the model does not have one; if prune_lm_head=='auto', TextPruner will try to prune the lm_head and will continue if the model does not have one.

  • config_class – Type of the configurations. Users should not change its value.
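
A minimal usage sketch (the argument values are illustrative):

Example:

from textpruner.configurations import VocabularyPruningConfig

# Keep only tokens that appear at least twice in the corpus; prune the
# lm_head if the model has one, and continue silently otherwise.
vocabulary_pruning_config = VocabularyPruningConfig(min_count=2, prune_lm_head='auto')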

TransformerPruningConfig

class textpruner.configurations.TransformerPruningConfig(target_ffn_size: Optional[int] = None, target_num_of_heads: Optional[int] = None, pruning_method: str = 'masks', ffn_even_masking: Optional[bool] = True, head_even_masking: Optional[bool] = True, n_iters: Optional[int] = 1, multiple_of: int = 1, pruning_order: Optional[str] = None, use_logits: bool = False, config_class: str = 'TransformerPruningConfig')[source]

Configurations for transformer pruning.

Parameters
  • target_ffn_size – the target average FFN size per layer.

  • target_num_of_heads – the target average number of heads per layer.

  • pruning_method – 'masks' or 'iterative'. If set to 'masks', the pruner prunes the model with the given masks (head_mask and ffn_mask). If set to 'iterative', the pruner calculates the importance scores of the neurons from the data provided by the dataloader and then prunes the model based on the scores.

  • ffn_even_masking – Whether the FFN size of each layer should be the same.

  • head_even_masking – Whether the number of attention heads of each layer should be the same.

  • n_iters – if pruning_method is set to 'iterative', n_iters is the number of iterations over which the model is pruned progressively.

  • multiple_of – if ffn_even_masking is False, restricts the target FFN size of each layer to be a multiple of multiple_of.

  • pruning_order – None or 'head-first' or 'ffn-first'. If None, the attention heads and the FFN layers are pruned simultaneously; if set to 'head-first' or 'ffn-first', they are pruned in the given order, and the actual number of iterations is 2*n_iters.

  • use_logits – if True, performs self-supervised pruning, where the logits are treated as the soft labels.

  • config_class – Type of the configurations. Users should not change its value.

Warning

If ffn_even_masking is False, the pruned model cannot be saved normally (the saved weights cannot be loaded with the transformers library), so make sure to set save_model=False when calling TransformerPruner.prune() or PipelinePruner.prune(). There are two ways to save the model in this case:

  • Save the model in TorchScript format manually;

  • Set keep_shape=True when calling TransformerPruner.prune() or PipelinePruner.prune(), so the model structure is unchanged and the full model can be saved. Then save the ffn_mask and head_mask. When loading the model, load the full model and then prune it with the masks.
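
A minimal usage sketch of the configuration (the target sizes are illustrative):

Example:

from textpruner.configurations import TransformerPruningConfig

# Iteratively prune to an average of 2048 FFN neurons and 8 attention heads
# per layer, spread over 4 pruning iterations.
transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=2048,
    target_num_of_heads=8,
    pruning_method='iterative',
    n_iters=4)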

Pruners

VocabularyPruner

class textpruner.VocabularyPruner(model: torch.nn.modules.module.Module, tokenizer, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)[source]
Parameters
  • model – The model to be pruned.

  • tokenizer – The tokenizer for the model.

  • vocabulary_pruning_config – a VocabularyPruningConfig object.

  • general_config – a GeneralConfig object.

  • base_model_prefix – The prefix of the base model, i.e., the name of the base model as a member of the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner will try to infer the base_model_prefix, so its value can usually be left as None; if the inference fails, users have to set its value explicitly.

VocabularyPruner.prune(dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str][source]

Prunes the vocabulary of the model and the tokenizer. The pruner will only keep the tokens that appear in dataiter, additional_tokens and additional_token_ids.

  • Use dataiter to generate a set of tokens from the raw texts.

  • Use additional_tokens or additional_token_ids to specify the tokens or token ids directly, without running tokenization.

Parameters
  • dataiter – a list of strings (the raw texts). These strings will be tokenized by the tokenizer to generate the set of tokens.

  • additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.

  • additional_token_ids – a list of ints representing the token ids.

  • save_model – whether to save the model when the pruning is finished.
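
A minimal usage sketch, assuming model, tokenizer, and texts (a list of strings) already exist; the additional tokens shown are BERT special tokens and must exist in the original vocabulary:

Example:

from textpruner import VocabularyPruner

# Default configurations are used when none are passed.
pruner = VocabularyPruner(model, tokenizer)
pruner.prune(dataiter=texts, additional_tokens=['[CLS]', '[SEP]'], save_model=True)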

TransformerPruner

class textpruner.TransformerPruner(model: torch.nn.modules.module.Module, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)[source]
Parameters
  • model – The model to be pruned.

  • transformer_pruning_config – a TransformerPruningConfig object.

  • general_config – a GeneralConfig object.

  • base_model_prefix – The prefix of the base model, i.e., the name of the base model as a member of the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner will try to infer the base_model_prefix, so its value can usually be left as None; if the inference fails, users have to set its value explicitly.

TransformerPruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, save_model=True, rewrite_cache=True)[source]

Prunes the transformers. If self.transformer_pruning_config.pruning_method=='masks', the pruner prunes the attention heads and the FFN neurons based on the given head_mask and ffn_mask; if self.transformer_pruning_config.pruning_method=='iterative', the pruner prunes the attention heads and the FFN neurons based on the importance scores calculated on the batches from the dataloader.

Parameters
  • dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.

  • adaptor – a function that takes the model output and returns the loss.

  • batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. It is used for post-processing the batches if needed.

  • head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.

  • ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.

  • keep_shape – if True, the model is not actually pruned and its structure is unchanged; instead, the weights that should be pruned are set to zero.

  • save_model – whether to save the model when the pruning is finished.
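
A minimal usage sketch of iterative pruning, assuming model and dataloader already exist; the adaptor below assumes the model returns an output object with a .loss attribute (as transformers models do when the labels are included in the batch):

Example:

from textpruner import TransformerPruner
from textpruner.configurations import TransformerPruningConfig

def adaptor(model_outputs):
    # Assumption: the forward pass computes the loss from the labels in the batch.
    return model_outputs.loss

transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=2048, target_num_of_heads=8,
    pruning_method='iterative', n_iters=4)
pruner = TransformerPruner(model, transformer_pruning_config=transformer_pruning_config)
pruner.prune(dataloader=dataloader, adaptor=adaptor, save_model=True)

For mask-based pruning, set pruning_method='masks' in the configuration and pass head_mask and ffn_mask to prune() instead of a dataloader.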

PipelinePruner

class textpruner.PipelinePruner(model: torch.nn.modules.module.Module, tokenizer, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)[source]
Parameters
  • model – The model to be pruned.

  • tokenizer – The tokenizer for the model.

  • vocabulary_pruning_config – a VocabularyPruningConfig object.

  • transformer_pruning_config – a TransformerPruningConfig object.

  • general_config – a GeneralConfig object.

  • base_model_prefix – The prefix of the base model, i.e., the name of the base model as a member of the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner will try to infer the base_model_prefix, so its value can usually be left as None; if the inference fails, users have to set its value explicitly.

PipelinePruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str][source]

Prunes the transformers, then prunes the vocabulary.

Parameters
  • dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.

  • adaptor – a function that takes the model output and returns the loss.

  • batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. It is used for post-processing the batches if needed.

  • head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.

  • ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.

  • keep_shape – if True, the model is not actually pruned and its structure is unchanged; instead, the weights that should be pruned are set to zero.

  • dataiter – a list of strings (the raw texts). These strings will be tokenized by the tokenizer to generate the set of tokens.

  • additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.

  • additional_token_ids – a list of ints representing the token ids.

  • save_model – whether to save the model when the pruning is finished.
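
A minimal usage sketch, assuming model, tokenizer, dataloader, adaptor, and texts already exist (see the TransformerPruner and VocabularyPruner examples above):

Example:

from textpruner import PipelinePruner
from textpruner.configurations import TransformerPruningConfig, VocabularyPruningConfig

# Prune the transformer first, then prune the vocabulary, in one call.
pruner = PipelinePruner(
    model, tokenizer,
    transformer_pruning_config=TransformerPruningConfig(
        target_ffn_size=2048, target_num_of_heads=8, pruning_method='iterative'),
    vocabulary_pruning_config=VocabularyPruningConfig(min_count=1))
pruner.prune(dataloader=dataloader, adaptor=adaptor, dataiter=texts, save_model=True)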

Utils

summary

textpruner.summary(model: Union[torch.nn.modules.module.Module, Dict], max_level: Optional[int] = 2)[source]

Show the summary of model parameters.

Parameters
  • model – the model to be inspected, can be a torch module or a state_dict.

  • max_level – The max level to display. If max_level is None, show all the levels.

Returns

A formatted string.

Example:

print(textpruner.summary(model))

inference_time

textpruner.inference_time(model: torch.nn.modules.module.Module, dummy_inputs: Union[List, Tuple, Dict], warm_up: int = 5, repetitions: int = 10)[source]

Measure and print the inference time of the model.

Parameters
  • model – the torch module to be measured.

  • dummy_inputs – the inputs to be fed into the model; can be a list, tuple, or dict.

  • warm_up – Number of steps to warm up the device.

  • repetitions – Number of steps to perform forward propagation. More repetitions result in more accurate measurements.

Example:

import torch

input_ids = torch.randint(low=0, high=10000, size=(32, 256))
textpruner.inference_time(model, dummy_inputs=[input_ids])