Pruners

VocabularyPruner

class textpruner.VocabularyPruner(model: torch.nn.modules.module.Module, tokenizer, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)[source]
Parameters
  • model – The model to be pruned.

  • tokenizer – The tokenizer for the model.

  • vocabulary_pruning_config – a VocabularyPruningConfig object.

  • general_config – a GeneralConfig object.

  • base_model_prefix – The prefix of the base model, i.e., the name of the attribute under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner tries to infer the base_model_prefix automatically, so its value can usually be left as None; if the inference fails, users must set it explicitly.

VocabularyPruner.prune(dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str][source]

Prunes the vocabulary of the model and the tokenizer. The pruner keeps only the tokens found in dataiter, additional_tokens, and additional_token_ids.

  • Use dataiter to generate a set of tokens from the raw texts.

  • Use additional_tokens or additional_token_ids to specify the tokens or token ids directly, without running the tokenizer.

Parameters
  • dataiter – a list of strings (raw or pre-tokenized texts). These strings will be tokenized by the tokenizer to generate the set of tokens to keep.

  • additional_tokens – a list of tokens. These tokens must already exist in the original vocabulary.

  • additional_token_ids – a list of ints representing the token ids.

  • save_model – whether to save the model when the pruning is finished.
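
Example: a minimal usage sketch (not from the official docs), assuming a HuggingFace transformers model and tokenizer; the model name and the corpus below are illustrative only:

    from transformers import AutoModelForMaskedLM, AutoTokenizer
    from textpruner import VocabularyPruner

    model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    # A hypothetical in-domain corpus: one string per example.
    texts = ['the first in-domain sentence', 'the second in-domain sentence']

    pruner = VocabularyPruner(model, tokenizer)
    # Keep only the tokens produced by tokenizing `texts`; all other tokens
    # are removed from the model's embeddings and the tokenizer's vocabulary.
    pruner.prune(dataiter=texts, save_model=True)

When save_model is True, the pruned model and tokenizer are written to the output directory specified in GeneralConfig.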

TransformerPruner

class textpruner.TransformerPruner(model: torch.nn.modules.module.Module, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)[source]
Parameters
  • model – The model to be pruned.

  • transformer_pruning_config – a TransformerPruningConfig object.

  • general_config – a GeneralConfig object.

  • base_model_prefix – The prefix of the base model, i.e., the name of the attribute under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner tries to infer the base_model_prefix automatically, so its value can usually be left as None; if the inference fails, users must set it explicitly.

TransformerPruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, save_model=True, rewrite_cache=True)[source]

Prunes the transformer layers. If self.transformer_pruning_config.pruning_method=='masks', the pruner prunes the attention heads and the FFN neurons according to head_mask and ffn_mask; if self.transformer_pruning_config.pruning_method=='iterative', the pruner prunes the attention heads and the FFN neurons based on the importance scores calculated on the batches from the dataloader.

Parameters
  • dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.

  • adaptor – a function that takes the model output and returns the loss.

  • batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. Use it to post-process batches if needed.

  • head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.

  • ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.

  • keep_shape – if True, the model is not actually pruned and the model structure is unchanged; the weights that should be pruned are set to zero instead.

  • save_model – whether to save the model when the pruning is finished.
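
Example: a minimal sketch of iterative pruning (not from the official docs). The TransformerPruningConfig field names below (target_ffn_size, target_num_of_heads, n_iters) are assumed from the configuration documentation; model and dataloader are hypothetical, with dataloader yielding batches that contain both inputs and labels:

    from textpruner import TransformerPruner, TransformerPruningConfig

    config = TransformerPruningConfig(
        target_ffn_size=2048,        # assumed field: FFN neurons kept per layer
        target_num_of_heads=8,       # assumed field: attention heads kept per layer
        pruning_method='iterative',
        n_iters=4)                   # assumed field: number of pruning iterations
    pruner = TransformerPruner(model, transformer_pruning_config=config)

    def adaptor(model_output):
        # Reduce the model output to a scalar loss;
        # the loss is used to compute the importance scores.
        return model_output.loss

    pruner.prune(dataloader=dataloader, adaptor=adaptor, save_model=True)

With pruning_method='masks', no dataloader or adaptor is needed; instead, pass a head_mask of shape (num_layers, num_attention_heads) and an ffn_mask of shape (num_layers, intermediate_hidden_size), e.g. torch.ones(12, 12) with zeros at the head positions to prune for a BERT-base model.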

PipelinePruner

class textpruner.PipelinePruner(model: torch.nn.modules.module.Module, tokenizer, transformer_pruning_config: Optional[textpruner.configurations.TransformerPruningConfig] = None, vocabulary_pruning_config: Optional[textpruner.configurations.VocabularyPruningConfig] = None, general_config: Optional[textpruner.configurations.GeneralConfig] = None, base_model_prefix: Optional[str] = None)[source]
Parameters
  • model – The model to be pruned.

  • tokenizer – The tokenizer for the model.

  • vocabulary_pruning_config – a VocabularyPruningConfig object.

  • transformer_pruning_config – a TransformerPruningConfig object.

  • general_config – a GeneralConfig object.

  • base_model_prefix – The prefix of the base model, i.e., the name of the attribute under which the base model is stored in the model. For example, if model.bert_encoder = BertModel(...), then the base_model_prefix is bert_encoder. TextPruner tries to infer the base_model_prefix automatically, so its value can usually be left as None; if the inference fails, users must set it explicitly.

PipelinePruner.prune(dataloader=None, adaptor=None, batch_postprocessor=None, head_mask: Optional[torch.Tensor] = None, ffn_mask: Optional[torch.Tensor] = None, keep_shape=False, dataiter=None, additional_tokens=None, additional_token_ids=None, save_model=True) → Optional[str][source]

Prunes the transformer layers first, then prunes the vocabulary.

Parameters
  • dataloader – a dataloader that generates batches. Each batch should contain both the inputs and the labels.

  • adaptor – a function that takes the model output and returns the loss.

  • batch_postprocessor – a function that takes a batch produced by the dataloader and returns a batch. Use it to post-process batches if needed.

  • head_mask – a tensor of shape (num_layers, num_attention_heads). 1 means to keep, 0 means to prune.

  • ffn_mask – a tensor of shape (num_layers, intermediate_hidden_size). 1 means to keep, 0 means to prune.

  • keep_shape – if True, the model is not actually pruned and the model structure is unchanged; the weights that should be pruned are set to zero instead.

  • dataiter – a list of strings (raw or pre-tokenized texts). These strings will be tokenized by the tokenizer to generate the set of tokens to keep.

  • additional_tokens – a list of tokens. These tokens must already exist in the original vocabulary.

  • additional_token_ids – a list of ints representing the token ids.

  • save_model – whether to save the model when the pruning is finished.
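
Example: a minimal sketch combining both pruning stages (not from the official docs). The config field names are assumed as above; model, tokenizer, dataloader, and texts are hypothetical, with dataloader driving the transformer pruning and texts driving the vocabulary pruning:

    from textpruner import (PipelinePruner, TransformerPruningConfig,
                            VocabularyPruningConfig)

    pruner = PipelinePruner(
        model, tokenizer,
        transformer_pruning_config=TransformerPruningConfig(
            target_ffn_size=2048, target_num_of_heads=8,  # assumed fields
            pruning_method='iterative', n_iters=4),       # assumed fields
        vocabulary_pruning_config=VocabularyPruningConfig())

    # Transformer pruning consumes `dataloader`; vocabulary pruning consumes `texts`.
    pruner.prune(dataloader=dataloader, dataiter=texts, save_model=True)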