Pruners
VocabularyPruner
- class textpruner.VocabularyPruner(model: torch.nn.modules.module.Module, tokenizer, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
tokenizer – The tokenizer for the model.
vocabulary_pruning_config – a VocabularyPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the attribute name under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner infers the base_model_prefix automatically, so it can usually be left as None. If the inference fails, users have to set it explicitly.
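To make the meaning of base_model_prefix concrete, here is a plain-Python sketch. The classes `BertModel` and `MyClassifier` are hypothetical stand-ins (not the Hugging Face classes), and `find_base_model_prefix` is a hypothetical helper that returns the name of the attribute holding the base model. It only illustrates the idea; TextPruner's own inference logic may work differently.

```python
from typing import Optional


class BertModel:
    """Hypothetical stand-in for a base transformer model."""


class MyClassifier:
    """Hypothetical task model that stores its base model as `bert_encoder`."""

    def __init__(self) -> None:
        self.bert_encoder = BertModel()
        self.classifier_head = object()


def find_base_model_prefix(model: object, base_cls: type) -> Optional[str]:
    """Hypothetical helper: return the name of the first instance attribute
    whose value is an instance of `base_cls`, or None if there is none."""
    for name, value in vars(model).items():
        if isinstance(value, base_cls):
            return name
    return None


model = MyClassifier()
print(find_base_model_prefix(model, BertModel))  # bert_encoder
```

When automatic inference cannot find such an attribute, passing the attribute name (here "bert_encoder") as base_model_prefix is what the parameter description asks for.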
- VocabularyPruner.prune(dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) -> Optional[str]
Prunes the vocabulary of the model and the tokenizer. The pruner keeps only the tokens in dataiter, additional_tokens and additional_token_ids.
Use dataiter to generate a set of tokens from raw texts.
Use additional_tokens or additional_token_ids to specify tokens or token IDs directly, without running the tokenizer.
- Parameters
dataiter – a list of raw text strings (not yet tokenized). These strings will be tokenized by the tokenizer to generate a set of tokens.
additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.
additional_token_ids – a list of ints representing token IDs.
save_model – whether to save the model when the pruning is finished.
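The keep rule described above (retain the union of the tokens found in dataiter, additional_tokens and additional_token_ids) can be sketched in plain Python. This is a toy illustration, not TextPruner's implementation: the whitespace tokenizer, the dict vocabulary, the list-of-rows embedding matrix and the helper `prune_vocabulary` are all hypothetical stand-ins for a real tokenizer and a torch embedding layer.

```python
from typing import Dict, Iterable, List, Optional, Set, Tuple


def prune_vocabulary(
    vocab: Dict[str, int],
    embeddings: List[List[float]],
    dataiter: Optional[Iterable[str]] = None,
    additional_tokens: Optional[Iterable[str]] = None,
    additional_token_ids: Optional[Iterable[int]] = None,
) -> Tuple[Dict[str, int], List[List[float]], Dict[int, int]]:
    """Toy vocabulary pruning: keep only tokens that occur in the texts or are
    listed explicitly, re-index them, and slice the matching embedding rows."""
    keep: Set[int] = set()
    # Tokenize raw texts (toy whitespace tokenizer; unknown words are skipped).
    for text in dataiter or []:
        for word in text.split():
            if word in vocab:
                keep.add(vocab[word])
    # Explicit tokens must already exist in the original vocabulary.
    for token in additional_tokens or []:
        if token not in vocab:
            raise KeyError(f"{token!r} is not in the original vocabulary")
        keep.add(vocab[token])
    # Explicit IDs are added as-is, without tokenization.
    keep.update(additional_token_ids or [])

    kept_ids = sorted(keep)  # preserve the original relative order
    old_to_new = {old: new for new, old in enumerate(kept_ids)}
    id_to_token = {i: t for t, i in vocab.items()}
    new_vocab = {id_to_token[old]: new for old, new in old_to_new.items()}
    new_embeddings = [embeddings[old] for old in kept_ids]
    return new_vocab, new_embeddings, old_to_new


vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2, "cat": 3, "dog": 4, "sat": 5}
embeddings = [[float(i)] * 2 for i in range(len(vocab))]
new_vocab, new_emb, old_to_new = prune_vocabulary(
    vocab, embeddings,
    dataiter=["the cat sat"],
    additional_tokens=["[UNK]"],
    additional_token_ids=[0],
)
print(new_vocab)  # {'[PAD]': 0, '[UNK]': 1, 'the': 2, 'cat': 3, 'sat': 4}
```

Sorting the kept IDs preserves the original token order, and the returned old-to-new map is what you would need to re-encode inputs for the smaller vocabulary. The token "dog" never appears in the texts or the explicit lists, so it is dropped.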
TransformerPruner
- class textpruner.TransformerPruner(model: torch.nn.modules.module.Module, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
transformer_pruning_config – a TransformerPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the attribute name under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner infers the base_model_prefix automatically, so it can usually be left as None. If the inference fails, users have to set it explicitly.
- TransformerPruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, save_model=True, rewrite_cache=True)
Prunes the transformers. If self.transformer_pruning_config.pruning_method == 'masks', the pruner prunes the attention heads and the FFN neurons according to head_mask and ffn_mask; if self.transformer_pruning_config.pruning_method == 'iterative', the pruner prunes the attention heads and the FFN neurons based on importance scores computed on the batches from the dataloader.
- Parameters
dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.
adaptor – a function that takes the model output and returns the loss.
batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. Use it to post-process the batches if needed.
head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means keep, 0 means prune.
ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means keep, 0 means prune.
keep_shape – if True, the model is not actually pruned and its structure is unchanged; instead, the weights that should be pruned are set to zero.
save_model – whether to save the model when the pruning is finished.
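The mask convention (1 = keep, 0 = prune) and the effect of keep_shape can be illustrated without a model. In this sketch plain nested lists stand in for the torch.Tensor masks the API expects, and `make_mask` and `apply_ffn_mask` are hypothetical helpers. It assumes each row of the toy weight matrix corresponds to one intermediate FFN neuron; how TextPruner lays out and slices the real weights is not shown here.

```python
from typing import Dict, List

Matrix = List[List[float]]


def make_mask(num_layers: int, width: int, prune: Dict[int, List[int]]) -> List[List[int]]:
    """Build a (num_layers, width) 0/1 mask. `prune` maps a layer index to the
    positions (heads or neurons) to prune; every other entry is 1 (keep)."""
    mask = [[1] * width for _ in range(num_layers)]
    for layer, positions in prune.items():
        for p in positions:
            mask[layer][p] = 0
    return mask


def apply_ffn_mask(weight: Matrix, layer_mask: List[int], keep_shape: bool) -> Matrix:
    """Apply one layer's FFN mask to a toy weight whose rows are neurons.
    keep_shape=True zeroes the masked rows; keep_shape=False removes them."""
    if keep_shape:
        return [row if m else [0.0] * len(row) for row, m in zip(weight, layer_mask)]
    return [row for row, m in zip(weight, layer_mask) if m]


# head_mask of shape (num_layers=2, num_attention_heads=12): prune heads 0 and 5 of layer 1.
head_mask = make_mask(2, 12, {1: [0, 5]})
# ffn_mask of shape (num_layers=2, intermediate_hidden_size=4): prune neurons 1 and 3 of layer 0.
ffn_mask = make_mask(2, 4, {0: [1, 3]})
print(ffn_mask)  # [[1, 0, 1, 0], [1, 1, 1, 1]]

weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(apply_ffn_mask(weight, ffn_mask[0], keep_shape=True))   # [[1.0, 2.0], [0.0, 0.0], [5.0, 6.0], [0.0, 0.0]]
print(apply_ffn_mask(weight, ffn_mask[0], keep_shape=False))  # [[1.0, 2.0], [5.0, 6.0]]
```

With keep_shape=True the tensor shapes stay the same, so the result loads into the original architecture; with keep_shape=False the layer actually shrinks.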
PipelinePruner
- class textpruner.PipelinePruner(model: torch.nn.modules.module.Module, tokenizer, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
tokenizer – The tokenizer for the model.
vocabulary_pruning_config – a VocabularyPruningConfig object.
transformer_pruning_config – a TransformerPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the attribute name under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner infers the base_model_prefix automatically, so it can usually be left as None. If the inference fails, users have to set it explicitly.
- PipelinePruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) -> Optional[str]
Prunes the transformers, then prunes the vocabulary.
- Parameters
dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.
adaptor – a function that takes the model output and returns the loss.
batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. Use it to post-process the batches if needed.
head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means keep, 0 means prune.
ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means keep, 0 means prune.
keep_shape – if True, the model is not actually pruned and its structure is unchanged; instead, the weights that should be pruned are set to zero.
dataiter – a list of raw text strings (not yet tokenized). These strings will be tokenized by the tokenizer to generate a set of tokens.
additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.
additional_token_ids – a list of ints representing the token ids.
save_model – whether to save the model when the pruning is finished.
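The adaptor and batch_postprocessor parameters (shared by TransformerPruner.prune and PipelinePruner.prune) are plain callables. The sketch below shows one possible shape for each. It assumes Hugging Face-style model outputs, which expose the loss either as a `.loss` attribute or as the first element of a tuple when labels are supplied, and a hypothetical dataloader that yields (input_ids, attention_mask, labels) tuples. How TextPruner feeds the post-processed batch to the model is not specified here, so treat the dict layout as an assumption.

```python
from types import SimpleNamespace
from typing import Any, Dict, Sequence


def adaptor(model_outputs: Any) -> Any:
    """Return the loss from a model's output. Assumes Hugging Face-style
    outputs: an object with a `.loss` attribute, or a tuple whose first
    element is the loss."""
    loss = getattr(model_outputs, "loss", None)
    if loss is not None:
        return loss
    return model_outputs[0]


def batch_postprocessor(batch: Sequence[Any]) -> Dict[str, Any]:
    """Hypothetical post-processing: turn an (input_ids, attention_mask, labels)
    tuple from the dataloader into a dict of named inputs, labels included."""
    input_ids, attention_mask, labels = batch
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


print(adaptor(SimpleNamespace(loss=0.5, logits=[0.1, 0.9])))  # 0.5
print(adaptor((0.25, [0.1, 0.9])))  # 0.25
print(batch_postprocessor(([101, 2023, 102], [1, 1, 1], 1))["labels"])  # 1
```

These functions would then be passed as adaptor=adaptor and batch_postprocessor=batch_postprocessor to prune(), alongside the dataloader, when the pruning method needs importance scores.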