Welcome to TextPruner’s documentation

TextPruner is a toolkit for pruning pre-trained transformer-based language models written in PyTorch. It offers structured training-free pruning methods and a user-friendly interface.
The main features of TextPruner include:
Compatibility: TextPruner is compatible with different NLU pre-trained models. You can use it to prune your own models for various NLP tasks as long as they are built on the standard pre-trained models.
Usability: TextPruner can be used as a package or a CLI tool. They are both easy to use.
Efficiency: TextPruner reduces the model size in a simple and fast way. TextPruner uses structured training-free methods to prune models. It is much faster than distillation and other pruning methods that involve training.
TextPruner currently supports the following pre-trained models in transformers:
BERT
ALBERT
ELECTRA
RoBERTa
XLM-RoBERTa
Installation
pip install textpruner
Note
This document is under development.
Configurations
Config
- class textpruner.configurations.Config
Base class for GeneralConfig, VocabularyPruningConfig and TransformerPruningConfig.
GeneralConfig
- class textpruner.configurations.GeneralConfig(use_device: str = 'auto', output_dir: str = './pruned_models', config_class: str = 'GeneralConfig')
Configurations for the device and the output directory.
- Parameters
use_device – 'cpu' or 'cuda' or 'cuda:0' etc. Specifies which device to use. If it is set to 'auto', TextPruner will try to use the CUDA device if there is one; otherwise it uses the CPU.
output_dir – The directory to save the pruned models.
config_class – Type of the configurations. Users should not change its value.
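The 'auto' behaviour described above can be pictured as a small resolution step. The sketch below is an illustration only, not TextPruner's code: resolve_device and the cuda_available flag are hypothetical names, and returning plain 'cuda' in the automatic case is an assumption.

```python
def resolve_device(use_device: str = "auto", cuda_available: bool = False) -> str:
    """Return the device string that a given use_device setting would select.

    Hypothetical helper for illustration; cuda_available stands in for a
    runtime check such as torch.cuda.is_available().
    """
    if use_device == "auto":
        # Prefer a CUDA device when one exists, otherwise fall back to the CPU.
        return "cuda" if cuda_available else "cpu"
    # Any explicit value ('cpu', 'cuda', 'cuda:0', ...) is used as given.
    return use_device


print(resolve_device("auto", cuda_available=False))
print(resolve_device("cuda:0", cuda_available=True))
```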
VocabularyPruningConfig
- class textpruner.configurations.VocabularyPruningConfig(min_count: int = 1, prune_lm_head: Union[bool, str] = 'auto', config_class: str = 'VocabularyPruningConfig')
Configurations for vocabulary pruning.
- Parameters
min_count – The threshold to decide if the token should be removed. The token will be removed from the vocabulary if it appears fewer than min_count times in the corpus.
prune_lm_head – Whether to prune the lm_head if the model has one. If prune_lm_head==False, TextPruner will not prune the lm_head; if prune_lm_head==True, TextPruner will prune the lm_head and raise an error if the model does not have an lm_head; if prune_lm_head=='auto', TextPruner will try to prune the lm_head and will continue if the model does not have an lm_head.
config_class – Type of the configurations. Users should not change its value.
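As a rough illustration of the two parameters above, the following sketch counts tokens in a toy corpus and applies the min_count threshold, then spells out the three prune_lm_head settings as a decision function. The names tokens_to_keep and should_prune_lm_head are hypothetical, and raising ValueError is an assumption; the documentation only says that an error is raised.

```python
from collections import Counter
from typing import Iterable, Set, Union


def tokens_to_keep(corpus_tokens: Iterable[str], min_count: int = 1) -> Set[str]:
    """Keep tokens that appear at least min_count times in the corpus."""
    counts = Counter(corpus_tokens)
    # A token seen fewer than min_count times is dropped from the vocabulary.
    return {token for token, n in counts.items() if n >= min_count}


def should_prune_lm_head(prune_lm_head: Union[bool, str], has_lm_head: bool) -> bool:
    """Mirror the three documented settings of prune_lm_head."""
    if prune_lm_head is False:
        return False
    if prune_lm_head is True:
        if not has_lm_head:
            # Assumed exception type; the docs only state that an error is raised.
            raise ValueError("prune_lm_head=True but the model has no lm_head")
        return True
    if prune_lm_head == "auto":
        # Prune the head when it exists; otherwise carry on without it.
        return has_lm_head
    raise ValueError(f"unexpected prune_lm_head value: {prune_lm_head!r}")


tokens = ["the", "cat", "the", "dog", "the", "cat"]
print(sorted(tokens_to_keep(tokens, min_count=2)))  # ['cat', 'the']
```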
TransformerPruningConfig
- class textpruner.configurations.TransformerPruningConfig(target_ffn_size: Optional[int] = None, target_num_of_heads: Optional[int] = None, pruning_method: str = 'masks', ffn_even_masking: Optional[bool] = True, head_even_masking: Optional[bool] = True, n_iters: Optional[int] = 1, multiple_of: int = 1, pruning_order: Optional[str] = None, use_logits: bool = False, config_class: str = 'TransformerPruningConfig')
Configurations for transformer pruning.
- Parameters
target_ffn_size – The target average FFN size per layer.
target_num_of_heads – The target average number of heads per layer.
pruning_method – 'masks' or 'iterative'. If set to 'masks', the pruner prunes the model with the given masks (head_mask and ffn_mask). If set to 'iterative', the pruner calculates the importance scores of the neurons based on the data provided by the dataloader and then prunes the model based on the scores.
ffn_even_masking – Whether the FFN size of each layer should be the same.
head_even_masking – Whether the number of attention heads of each layer should be the same.
n_iters – If pruning_method is set to 'iterative', n_iters is the number of pruning iterations used to prune the model progressively.
multiple_of – If ffn_even_masking is False, restrict the target FFN size of each layer to be a multiple of multiple_of.
pruning_order – None or 'head-first' or 'ffn-first'. None: prune the attention heads and the FFN layer simultaneously; if set to 'head-first' or 'ffn-first', the actual number of iterations is 2*n_iters.
use_logits – If True, performs self-supervised pruning, where the logits are treated as the soft labels.
config_class – Type of the configurations. Users should not change its value.
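Two of the parameters above involve simple arithmetic that is easy to check by hand. The sketch below is not taken from TextPruner: total_iterations and round_to_multiple are hypothetical helpers, and nearest-multiple rounding is an assumption, since the documentation only states that the per-layer FFN size is restricted to a multiple of multiple_of.

```python
from typing import Optional


def total_iterations(n_iters: int = 1, pruning_order: Optional[str] = None) -> int:
    """Number of pruning passes implied by n_iters and pruning_order."""
    if pruning_order is None:
        # Heads and FFN neurons are pruned together in each iteration.
        return n_iters
    if pruning_order in ("head-first", "ffn-first"):
        # Heads and FFN neurons are pruned in separate passes.
        return 2 * n_iters
    raise ValueError(f"unknown pruning_order: {pruning_order!r}")


def round_to_multiple(ffn_size: int, multiple_of: int = 1) -> int:
    """Round a per-layer FFN size to the nearest multiple of multiple_of.

    Assumption: nearest-multiple rounding with a floor of one multiple;
    the library may round differently.
    """
    rounded = multiple_of * round(ffn_size / multiple_of)
    return max(multiple_of, rounded)


print(total_iterations(n_iters=4, pruning_order="head-first"))  # 8
print(round_to_multiple(1500, multiple_of=64))  # 1472
```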
Warning
If ffn_even_masking is False, the pruned model cannot be saved normally (the saved weights cannot be loaded with the transformers library). So make sure to set save_model=False when calling TransformerPruner.prune() or PipelinePruner.prune(). There are two ways to work around this:
- Save the model in TorchScript format manually;
- Set keep_shape=True when calling TransformerPruner.prune() or PipelinePruner.prune(), so the full model can be saved. Then save the ffn_masks and head_masks. When loading the model, load the full model and then prune it with the masks.
Pruners
VocabularyPruner
- class textpruner.VocabularyPruner(model: torch.nn.modules.module.Module, tokenizer, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
tokenizer – The tokenizer for the model.
vocabulary_pruning_config – a VocabularyPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the name of the base model as a member in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner will try to infer the base_model_prefix, so its value can be left as None. If the inference fails, users have to set its value explicitly.
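To make the meaning of base_model_prefix concrete, here is a toy sketch that uses plain Python classes in place of real models. ToyEncoder and ToyClassifier are hypothetical stand-ins; the point is only that the prefix is the attribute name under which the base model is stored.

```python
class ToyEncoder:
    """Stand-in for a pre-trained base model such as BertModel."""


class ToyClassifier:
    """Stand-in for a task model that wraps a base model."""

    def __init__(self) -> None:
        # The attribute name 'bert_encoder' is what base_model_prefix refers to.
        self.bert_encoder = ToyEncoder()
        self.classifier = object()


model = ToyClassifier()
base_model_prefix = "bert_encoder"

# Looking up the attribute by name returns the wrapped base model.
base_model = getattr(model, base_model_prefix)
print(type(base_model).__name__)  # ToyEncoder
```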
- VocabularyPruner.prune(dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str]
Prunes the vocabulary of the model and the tokenizer. The pruner will only keep the tokens in dataiter, additional_tokens and additional_token_ids.
Use dataiter to generate a set of tokens from the raw texts.
Use additional_tokens or additional_token_ids to specify the tokens or token_ids directly without running the tokenization.
- Parameters
dataiter – a list of pre-tokenized strings. These strings will be tokenized by the tokenizer to generate a set of tokens.
additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.
additional_token_ids – a list of ints representing the token ids.
save_model – whether to save the model when the pruning is finished.
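The way the three inputs combine can be sketched with a toy vocabulary. This is an illustration under stated assumptions, not the library's implementation: kept_token_ids is a hypothetical helper, whitespace splitting stands in for the real tokenizer, and the min_count threshold is left out for brevity.

```python
from typing import Dict, Iterable, List, Optional


def kept_token_ids(
    vocab: Dict[str, int],
    dataiter: Optional[Iterable[str]] = None,
    additional_tokens: Optional[Iterable[str]] = None,
    additional_token_ids: Optional[Iterable[int]] = None,
) -> List[int]:
    """Collect the ids of all tokens that survive pruning, in ascending order."""
    keep = set()
    for text in dataiter or []:
        # Toy tokenizer: split on whitespace, skip out-of-vocabulary words.
        keep.update(vocab[tok] for tok in text.split() if tok in vocab)
    # Tokens named directly must already be in the original vocabulary.
    keep.update(vocab[tok] for tok in additional_tokens or [])
    keep.update(additional_token_ids or [])
    return sorted(keep)


vocab = {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3, "prune": 4, "me": 5}
ids = kept_token_ids(
    vocab,
    dataiter=["hello prune", "hello again"],
    additional_tokens=["[UNK]"],
    additional_token_ids=[0],
)
# Map each surviving old id to its position in the smaller vocabulary.
old_to_new = {old: new for new, old in enumerate(ids)}
print(ids)         # [0, 1, 2, 4]
print(old_to_new)  # {0: 0, 1: 1, 2: 2, 4: 3}
```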
TransformerPruner
- class textpruner.TransformerPruner(model: torch.nn.modules.module.Module, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
transformer_pruning_config – a TransformerPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the name of the base model as a member in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner will try to infer the base_model_prefix, so its value can be left as None. If the inference fails, users have to set its value explicitly.
- TransformerPruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, save_model=True, rewrite_cache=True)
Prunes the transformers. If self.transformer_pruning_config.pruning_method=='masks', the pruner prunes the attention heads and the FFN neurons based on the head_mask and ffn_mask; if self.transformer_pruning_config.pruning_method=='iterative', the pruner prunes the attention heads and the FFN neurons based on the importance scores calculated on the batches from the dataloader.
- Parameters
dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.
adaptor – a function that takes the model output and returns the loss.
batch_postprocessor – a function that takes the batch produced by the dataloader and returns a batch. It is used for post-processing the batches if needed.
head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.
ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.
keep_shape – if True, the model is not actually pruned and the model structure is not changed, but the weights that should be pruned are set to zero.
save_model – whether to save the model when the pruning is finished.
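The difference between the two keep_shape settings can be shown on a single toy weight matrix. The sketch below uses nested lists rather than tensors and is not TextPruner's code; apply_ffn_mask is a hypothetical helper, and treating each row as one FFN neuron is an assumption made for the illustration.

```python
from typing import List

Matrix = List[List[float]]


def apply_ffn_mask(weight: Matrix, mask_row: List[int], keep_shape: bool = False) -> Matrix:
    """Apply one layer's FFN mask to a weight matrix with one row per neuron."""
    if keep_shape:
        # Same shape as before; rows of pruned neurons are zeroed out.
        return [row if keep else [0.0] * len(row) for row, keep in zip(weight, mask_row)]
    # Smaller matrix; rows of pruned neurons are removed.
    return [row for row, keep in zip(weight, mask_row) if keep]


# Four FFN neurons, each with two input weights; the mask keeps neurons 0 and 2.
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
mask_row = [1, 0, 1, 0]

print(apply_ffn_mask(weight, mask_row, keep_shape=True))
# [[1.0, 2.0], [0.0, 0.0], [5.0, 6.0], [0.0, 0.0]]
print(apply_ffn_mask(weight, mask_row, keep_shape=False))
# [[1.0, 2.0], [5.0, 6.0]]
```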
PipelinePruner
- class textpruner.PipelinePruner(model: torch.nn.modules.module.Module, tokenizer, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
tokenizer – The tokenizer for the model.
vocabulary_pruning_config – a VocabularyPruningConfig object.
transformer_pruning_config – a TransformerPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the name of the base model as a member in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner will try to infer the base_model_prefix, so its value can be left as None. If the inference fails, users have to set its value explicitly.
- PipelinePruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str]
Prunes the transformers, then prunes the vocabulary.
- Parameters
dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.
adaptor – a function that takes the model output and returns the loss.
batch_postprocessor – a function that takes the batch produced by the dataloader and returns a batch. It is used for post-processing the batches if needed.
head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.
ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.
keep_shape – if True, the model is not actually pruned and the model structure is not changed, but the weights that should be pruned are set to zero.
dataiter – a list of pre-tokenized strings. These strings will be tokenized by the tokenizer to generate a set of tokens.
additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.
additional_token_ids – a list of ints representing the token ids.
save_model – whether to save the model when the pruning is finished.
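The adaptor and batch_postprocessor arguments are plain callables, so they can be written and checked without a model. The sketch below shows one possible shape for each. It assumes that the model output exposes a loss attribute (as transformers model outputs do when labels are passed) or else carries the loss as its first element, and that each batch is an (input_ids, attention_mask, labels) tuple; whether TextPruner passes a dict batch to the model as keyword arguments is also an assumption.

```python
from types import SimpleNamespace
from typing import Any, Dict, Sequence


def adaptor(model_outputs: Any) -> Any:
    """Return the loss from the model output.

    Assumes the output exposes a .loss attribute; otherwise the first
    element of the output is taken to be the loss.
    """
    if hasattr(model_outputs, "loss"):
        return model_outputs.loss
    return model_outputs[0]


def batch_postprocessor(batch: Sequence[Any]) -> Dict[str, Any]:
    """Turn an (input_ids, attention_mask, labels) tuple into keyword inputs."""
    input_ids, attention_mask, labels = batch
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Stand-in objects so the two callables can be exercised without a model.
fake_outputs = SimpleNamespace(loss=0.25, logits=None)
print(adaptor(fake_outputs))  # 0.25
print(batch_postprocessor(([101, 102], [1, 1], [0])))
```

Both functions would then be passed by name, for example adaptor=adaptor and batch_postprocessor=batch_postprocessor, when calling prune().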
Utils
summary
- textpruner.summary(model: Union[torch.nn.modules.module.Module, Dict], max_level: Optional[int] = 2)
Show the summary of model parameters.
- Parameters
model – the model to be inspected, can be a torch module or a state_dict.
max_level – The max level to display. If max_level==None, show all the levels.
- Returns
A formatted string.
Example:
print(textpruner.summary(model))
inference_time
- textpruner.inference_time(model: torch.nn.modules.module.Module, dummy_inputs: Union[List, Tuple, Dict], warm_up: int = 5, repetitions: int = 10)
Measure and print the inference time of the model.
- Parameters
model – the torch module to be measured.
dummy_inputs – the inputs to be fed into the model; can be a list, tuple or dict.
warm_up – Number of steps to warm up the device.
repetitions – Number of steps to perform forward propagation. More repetitions result in more accurate measurements.
Example:
input_ids = torch.randint(low=0, high=10000, size=(32, 256))
textpruner.inference_time(model, dummy_inputs=[input_ids])
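The roles of warm_up and repetitions can be seen in a generic timing loop. The sketch below is a plain-Python illustration of that pattern, not the implementation of inference_time: time_forward is a hypothetical helper, and a cheap lambda stands in for the model's forward pass.

```python
import time
from typing import Callable, List


def time_forward(forward: Callable[[], object], warm_up: int = 5, repetitions: int = 10) -> List[float]:
    """Return per-call wall-clock times in milliseconds for `repetitions` calls."""
    for _ in range(warm_up):
        # Warm-up calls are run but not timed.
        forward()
    timings = []
    for _ in range(repetitions):
        start = time.perf_counter()
        forward()
        timings.append((time.perf_counter() - start) * 1000.0)
    return timings


timings = time_forward(lambda: sum(range(10_000)), warm_up=2, repetitions=5)
mean_ms = sum(timings) / len(timings)
print(f"mean: {mean_ms:.3f} ms over {len(timings)} runs")
```

On a CUDA device, a measurement of this kind would also need to synchronize the device before reading the clock, because GPU work is launched asynchronously.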