Pruners
VocabularyPruner
- class textpruner.VocabularyPruner(model: torch.nn.modules.module.Module, tokenizer, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
tokenizer – The tokenizer for the model.
vocabulary_pruning_config – a VocabularyPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the attribute name under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then base_model_prefix is bert_encoder. TextPruner tries to infer base_model_prefix, so it can usually be left as None; if the inference fails, set it explicitly.
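To illustrate what base_model_prefix refers to, the sketch below uses plain Python stand-ins for a task model and its base model. The classes BaseModel and TaskModel and the helper find_base_model_prefix are hypothetical and not part of TextPruner. The lookup strategy (first trust a declared base_model_prefix attribute, as Hugging Face PreTrainedModel subclasses define, then scan the instance attributes for a base-model instance) is an assumption, not TextPruner's actual inference logic.

```python
from typing import Optional


class BaseModel:
    """Hypothetical stand-in for a base encoder such as BertModel."""


class TaskModel:
    """Hypothetical task model that stores its encoder under a custom name."""

    def __init__(self) -> None:
        self.bert_encoder = BaseModel()  # the base model lives at model.bert_encoder
        self.classifier = object()       # task-specific head, not a base model


def find_base_model_prefix(model: object, base_cls: type) -> Optional[str]:
    """Return the attribute name under which `model` stores a `base_cls` instance.

    Assumed strategy (hypothetical): use a declared `base_model_prefix` if it
    names an existing member; otherwise scan the instance attributes in order.
    """
    declared = getattr(model, "base_model_prefix", None)
    if isinstance(declared, str) and hasattr(model, declared):
        return declared
    for name, value in vars(model).items():
        if isinstance(value, base_cls):
            return name
    return None  # inference failed: the caller must pass the prefix explicitly


model = TaskModel()
prefix = find_base_model_prefix(model, BaseModel)
print(prefix)  # bert_encoder
```

With a real torch.nn.Module, submodules are registered through the module and exposed by model.named_children() rather than stored as plain instance attributes, so a scan like this would iterate named_children() instead of vars().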
- VocabularyPruner.prune(dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str]
Prunes the vocabulary of the model and the tokenizer. The pruner keeps only the tokens that come from dataiter, additional_tokens, and additional_token_ids.
Use dataiter to collect a set of tokens from raw texts.
Use additional_tokens or additional_token_ids to specify tokens or token ids directly, without running the tokenizer.
- Parameters
dataiter – a list of raw text strings. The tokenizer tokenizes these strings to collect the set of tokens to keep.
additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.
additional_token_ids – a list of ints representing the token ids.
save_model – whether to save the model when the pruning is finished.
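As a rough illustration of the idea (keep only the tokens gathered from dataiter, additional_tokens, and additional_token_ids, then remap the kept ids to a compact range), the sketch below uses a toy whitespace tokenizer and plain lists in place of a real tokenizer and embedding matrix. The names VOCAB, toy_tokenize, select_kept_ids, and prune_embeddings are hypothetical, and TextPruner's own implementation may differ in details such as how special tokens are handled.

```python
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

# Toy vocabulary: token -> id. Special tokens come first, as in BERT-style vocabularies.
VOCAB: Dict[str, int] = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
                         "the": 4, "cat": 5, "dog": 6, "sat": 7, "ran": 8}


def toy_tokenize(text: str) -> List[str]:
    """Hypothetical whitespace tokenizer that maps unknown words to [UNK]."""
    return [w if w in VOCAB else "[UNK]" for w in text.lower().split()]


def select_kept_ids(dataiter: Optional[Iterable[str]] = None,
                    additional_tokens: Optional[Sequence[str]] = None,
                    additional_token_ids: Optional[Sequence[int]] = None) -> List[int]:
    """Union of the token ids from the three sources, sorted by original id."""
    kept = set()
    for text in dataiter or []:
        kept.update(VOCAB[tok] for tok in toy_tokenize(text))
    for tok in additional_tokens or []:
        if tok not in VOCAB:  # tokens must already exist in the original vocabulary
            raise KeyError(f"{tok!r} is not in the original vocabulary")
        kept.add(VOCAB[tok])
    kept.update(additional_token_ids or [])
    return sorted(kept)


def prune_embeddings(embeddings: List[List[float]],
                     kept_ids: List[int]) -> Tuple[List[List[float]], Dict[int, int]]:
    """Keep only the embedding rows of kept tokens; return the rows and an old->new id map."""
    old_to_new = {old: new for new, old in enumerate(kept_ids)}
    new_rows = [embeddings[old] for old in kept_ids]
    return new_rows, old_to_new


# One 2-d embedding row per vocabulary entry (row i belongs to token id i).
embeddings = [[float(i), float(i)] for i in range(len(VOCAB))]
kept = select_kept_ids(dataiter=["The cat sat"],
                       additional_tokens=["[CLS]", "[SEP]"],
                       additional_token_ids=[0])
new_emb, mapping = prune_embeddings(embeddings, kept)
print(kept)     # [0, 2, 3, 4, 5, 7]
print(mapping)  # {0: 0, 2: 1, 3: 2, 4: 3, 5: 4, 7: 5}
```

Sorting the kept ids preserves the relative order of the original vocabulary, so the old-to-new mapping is monotonic and easy to apply to both the tokenizer and the embedding rows.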
TransformerPruner
- class textpruner.TransformerPruner(model: torch.nn.modules.module.Module, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
transformer_pruning_config – a TransformerPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the attribute name under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then base_model_prefix is bert_encoder. TextPruner tries to infer base_model_prefix, so it can usually be left as None; if the inference fails, set it explicitly.
- TransformerPruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, save_model=True, rewrite_cache=True)
Prunes the transformer. If self.transformer_pruning_config.pruning_method=='masks', the pruner prunes the attention heads and the FFN neurons according to head_mask and ffn_mask; if self.transformer_pruning_config.pruning_method=='iterative', the pruner prunes the attention heads and the FFN neurons according to importance scores computed on the batches from the dataloader.
- Parameters
dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.
adaptor – a function that takes the model output and returns the loss.
batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. Use it to post-process the batches if needed.
head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.
ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.
keep_shape – if True, the model is not actually pruned and its structure is unchanged; instead, the weights that should be pruned are set to zero.
save_model – whether to save the model when the pruning is finished.
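To make the mask convention concrete (1 keeps, 0 prunes), the sketch below converts a head_mask into a {layer: [head indices]} dictionary, the format accepted by Hugging Face's PreTrainedModel.prune_heads, and applies one row of an ffn_mask to a toy feed-forward weight in both keep_shape modes. Nested lists stand in for torch tensors, and heads_to_prune and apply_ffn_mask are hypothetical helpers, not TextPruner functions. A full implementation would also have to drop the matching columns of the second FFN projection and the corresponding slices of the attention weights; the sketch omits that.

```python
from typing import Dict, List

Matrix = List[List[float]]


def heads_to_prune(head_mask: List[List[int]]) -> Dict[int, List[int]]:
    """Map each layer index to the heads whose mask entry is 0 (layers with nothing to prune are skipped)."""
    result = {}
    for layer, row in enumerate(head_mask):
        pruned = [head for head, keep in enumerate(row) if keep == 0]
        if pruned:
            result[layer] = pruned
    return result


def apply_ffn_mask(weight: Matrix, mask_row: List[int], keep_shape: bool) -> Matrix:
    """Apply one layer's ffn_mask to a weight whose rows are intermediate neurons.

    keep_shape=True: zero the pruned rows and keep the original shape.
    keep_shape=False: drop the pruned rows, shrinking the matrix.
    """
    if len(weight) != len(mask_row):
        raise ValueError("mask length must equal the number of intermediate neurons")
    if keep_shape:
        return [row if keep else [0.0] * len(row) for row, keep in zip(weight, mask_row)]
    return [row for row, keep in zip(weight, mask_row) if keep]


# 2 layers x 4 heads: prune head 1 in layer 0 and heads 0 and 3 in layer 1.
head_mask = [[1, 0, 1, 1],
             [0, 1, 1, 0]]
print(heads_to_prune(head_mask))  # {0: [1], 1: [0, 3]}

# First FFN projection of one layer, shaped like a PyTorch nn.Linear weight:
# (intermediate_hidden_size, hidden_size) = (4, 2), so each row is one neuron.
w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
ffn_mask_row = [1, 0, 1, 0]
print(apply_ffn_mask(w, ffn_mask_row, keep_shape=True))   # [[1.0, 2.0], [0.0, 0.0], [5.0, 6.0], [0.0, 0.0]]
print(apply_ffn_mask(w, ffn_mask_row, keep_shape=False))  # [[1.0, 2.0], [5.0, 6.0]]
```

The keep_shape=True branch is useful for checking the effect of pruning without changing parameter shapes, since downstream code that expects the original dimensions keeps working.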
PipelinePruner
- class textpruner.PipelinePruner(model: torch.nn.modules.module.Module, tokenizer, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)
- Parameters
model – The model to be pruned.
tokenizer – The tokenizer for the model.
vocabulary_pruning_config – a VocabularyPruningConfig object.
transformer_pruning_config – a TransformerPruningConfig object.
general_config – a GeneralConfig object.
base_model_prefix – The prefix of the base model, i.e., the attribute name under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then base_model_prefix is bert_encoder. TextPruner tries to infer base_model_prefix, so it can usually be left as None; if the inference fails, set it explicitly.
- PipelinePruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str]
Prunes the transformer (attention heads and FFN neurons) first, then prunes the vocabulary.
- Parameters
dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.
adaptor – a function that takes the model output and returns the loss.
batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. Use it to post-process the batches if needed.
head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.
ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.
keep_shape – if True, the model is not actually pruned and its structure is unchanged; instead, the weights that should be pruned are set to zero.
dataiter – a list of raw text strings. The tokenizer tokenizes these strings to collect the set of tokens to keep.
additional_tokens – a list of tokens. These tokens must exist in the original vocabulary.
additional_token_ids – a list of ints representing the token ids.
save_model – whether to save the model when the pruning is finished.
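The dataloader, adaptor, and batch_postprocessor arguments fit together roughly like this: each batch is optionally post-processed, fed to the model, and the adaptor extracts the loss from the model output. The sketch below shows one possible pair of callables, assuming batches arrive as (input_ids, attention_mask, labels) tuples and that the model output is either a dict with a "loss" key or a tuple whose first element is the loss. The function names and fake_model are hypothetical, and the simplified loop assumes the processed batch is passed to the model as keyword arguments; TextPruner's actual calling convention may differ.

```python
from typing import Any, Dict, List, Sequence, Tuple

Batch = Tuple[List[List[int]], List[List[int]], List[int]]


def batch_postprocessor(batch: Batch) -> Dict[str, Any]:
    """Turn an (input_ids, attention_mask, labels) tuple into keyword arguments."""
    input_ids, attention_mask, labels = batch
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def adaptor(model_outputs: Any) -> float:
    """Extract the loss from a dict-like output, or from the first element of a tuple."""
    if isinstance(model_outputs, dict):
        return model_outputs["loss"]
    return model_outputs[0]


def fake_model(input_ids, attention_mask, labels):
    """Hypothetical model returning (loss, logits); the loss is a dummy number."""
    loss = float(sum(labels)) / len(labels)
    logits = [[0.0, 0.0] for _ in labels]
    return loss, logits


# Simplified loop, assuming the processed batch is passed as keyword arguments.
dataloader: Sequence[Batch] = [
    ([[2, 5, 3]], [[1, 1, 1]], [1]),
    ([[2, 6, 3], [2, 7, 3]], [[1, 1, 1], [1, 1, 1]], [0, 1]),
]
losses = [adaptor(fake_model(**batch_postprocessor(b))) for b in dataloader]
print(losses)  # [1.0, 0.5]
```

Keeping the loss extraction in a separate adaptor lets the same pruning loop work with models whose outputs are shaped differently, without changing the models themselves.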