import torch
from collections.abc import Mapping
from tqdm import tqdm
import time
from typing import Tuple, Union,Dict,Optional,List
class LayerNode:
    """A node in the tree of dotted parameter names built by ``summary``.

    Each node stores a display ``name`` (which may gain a ``(shared)`` or
    ``(partially shared)`` suffix), its dotted ``fullname``, its children keyed
    by their original un-suffixed name, and an accumulated ``value``. In
    ``summary`` the value is ``[#params, memory in MB]``.
    """

    _SHARED = '(shared)'
    _PARTIAL = '(partially shared)'

    def __init__(self, name, parent=None, value=None, fullname=None):
        self.name = name
        self.fullname = fullname
        # The ``value`` argument used to be silently discarded.
        self.value = value
        self.children_name = {}
        self.parent = parent

    def __contains__(self, key):
        return key in self.children_name

    def __getitem__(self, key):
        return self.children_name[key]

    def __setitem__(self, key, value):
        self.children_name[key] = value

    @staticmethod
    def _accumulate(current, value):
        """Return ``current + value``; sequences are summed element-wise."""
        if current is None:
            # Copy sequences so an ancestor never aliases a child's value object.
            return list(value) if isinstance(value, (tuple, list)) else value
        if isinstance(value, (tuple, list)):
            return [old + new for old, new in zip(current, value)]
        return current + value

    @classmethod
    def _retag(cls, child, parent):
        """Refresh ``parent``'s shared suffix after ``child``'s value was added to it."""
        if parent.name.endswith(cls._SHARED):
            # A parent tagged fully shared that now owns parameters is only partially shared.
            if parent.value[0] != 0:
                parent.name = parent.name[:-len(cls._SHARED)] + cls._PARTIAL
        elif child.name.endswith(cls._SHARED) and not parent.name.endswith(cls._PARTIAL):
            parent.name += cls._SHARED if parent.value[0] == 0 else cls._PARTIAL

    def update(self, value):
        """Add ``value`` to every ancestor's accumulated value and refresh shared tags.

        Args:
            value: a number, or a list/tuple of numbers summed element-wise.
                The shared-tag logic reads ``value[0]`` as the parameter count.

        A child named ``...(shared)`` tags its parent ``(shared)`` if the parent
        holds no parameters yet, else ``(partially shared)``. A parent already
        tagged ``(shared)`` that later gains parameters is downgraded to
        ``(partially shared)``.
        """
        child = self
        while child.parent is not None:
            parent = child.parent
            parent.value = self._accumulate(parent.value, value)
            self._retag(child, parent)
            child = parent

    def format(self, level=0, total=None, indent='--', max_level=None, max_length=None):
        """Render this subtree as an aligned text table.

        Args:
            level: depth of this node; level 0 also emits the header row.
            total: parameter count used as 100% for the RATIO column.
                Defaults to this node's own count.
            indent: prefix drawn before the name; each deeper level prepends one space.
            max_level: deepest level to expand; ``None`` expands everything.
            max_length: width of the name column. It is recomputed at level 0.

        Returns:
            The rendered table as a string.
        """
        # An empty tree (no parameters were added) renders as zero instead of crashing.
        value = self.value if self.value is not None else [0, 0.0]
        if total is None:
            total = value[0]
        parts = []
        if level == 0:
            max_length = self._max_name_length(indent, ' ', max_level=max_level) + 1
            parts.append('\n')
            parts.append(f"{'LAYER NAME':<{max_length}}\t{'#PARAMS':>15}\t{'RATIO':>10}\t{'MEM(MB)':>8}\n")
        at_limit = max_level is not None and level == max_level
        if not at_limit and len(self.children_name) == 1:
            # The single child's row would carry the same numbers, so print only the name.
            parts.append(f"{indent + self.name:{max_length}}\n")
        else:
            ratio = value[0] / total if total else 0.0
            parts.append(f"{indent + self.name + ':':<{max_length}}\t{value[0]:15,d}\t{ratio:>10.2%}\t{value[1]:>8.2f}\n")
        if not at_limit:
            for child in self.children_name.values():
                parts.append(child.format(level + 1, total, indent=' ' + indent,
                                          max_level=max_level, max_length=max_length))
        return ''.join(parts)

    def _max_name_length(self, indent1='--', indent2=' ', level=0, max_level=None):
        """Return the widest indented name in this subtree, down to ``max_level``."""
        own = len(self.name) + len(indent1) + level * len(indent2)
        if max_level is not None and level >= max_level:
            return own
        return max([own, *(child._max_name_length(indent1, indent2, level=level + 1, max_level=max_level)
                           for child in self.children_name.values())])
def summary(model: Union[torch.nn.Module, Dict], max_level: Optional[int] = 2):
    """
    Show the summary of model parameters.

    Args:
        model: the model to be inspected. Can be a torch module or a state_dict
            (any mapping from dotted parameter names to tensors). Non-tensor
            entries are skipped.
        max_level: The max level to display. If ``max_level==None``, show all the levels.

    Returns:
        A formatted string.

    Raises:
        TypeError: if ``model`` is neither a ``torch.nn.Module`` nor a mapping.

    Example::

        print(textpruner.summary(model))
    """
    if isinstance(model, torch.nn.Module):
        state_dict = model.state_dict()
    elif isinstance(model, Mapping):
        state_dict = model
    else:
        raise TypeError("model should be either torch.nn.Module or a dict")

    seen_ptrs = set()
    model_node = LayerNode('model', fullname='model')
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            # e.g. ``_extra_state`` entries: nothing to count, and no data_ptr().
            continue
        names = key.split('.')
        current = model_node
        for i, name in enumerate(names):
            if name not in current:
                current[name] = LayerNode(name, parent=current, fullname='.'.join(names[:i + 1]))
            current = current[name]
        ptr = value.data_ptr()
        # Tensors backed by already-seen storage are counted once.
        # Empty tensors typically all report data_ptr() == 0, so they never count as shared.
        if value.numel() > 0 and ptr in seen_ptrs:
            current.value = [0, 0]
            current.name += "(shared)"
            current.fullname += "(shared)"
        else:
            seen_ptrs.add(ptr)
            current.value = [value.numel(), value.numel() * value.element_size() / 1024 / 1024]
        current.update(current.value)
    return model_node.format(max_level=max_level)
def inference_time(model: torch.nn.Module, dummy_inputs: Union[List, Tuple, Dict], warm_up: int = 5, repetitions: int = 10):
    """
    Measure and print the inference time of the model.

    Args:
        model: the torch module to be measured. Its device is read from
            ``model.device`` when present. Otherwise it comes from the first
            parameter, or CPU if the module has no parameters.
        dummy_inputs: the inputs to be fed into the model, can be a list, tuple or dict.
        warm_up: Number of steps to warm up the device.
        repetitions: Number of steps to perform forward propagation. More repetitions result in more accurate measurements.

    Returns:
        ``(mean, std)`` of the per-step inference time in milliseconds.

    Raises:
        ValueError: if the model's device is neither CPU nor CUDA.

    The model's training/eval mode is restored afterwards, even if measurement fails.

    Example::

        input_ids = torch.randint(low=0,high=10000,size=(32,256))
        textpruner.inference_time(model,dummy_inputs=[input_ids])
    """
    device = getattr(model, 'device', None)
    if not isinstance(device, torch.device):
        # A plain nn.Module has no ``device`` attribute; infer it from the parameters.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    try:
        if device.type == 'cpu':
            mean, std = cpu_inference_time(model, dummy_inputs, warm_up, repetitions)
        elif device.type == 'cuda':
            mean, std = cuda_inference_time(model, dummy_inputs, warm_up, repetitions)
        else:
            raise ValueError(f"Unknown device {device}")
    finally:
        model.train(was_training)

    print(f"Device: {device}")
    print(f"Mean inference time: {mean:.2f}ms")
    print(f"Standard deviation: {std:.2f}ms")
    return mean, std
def cuda_inference_time(model: torch.nn.Module, dummy_inputs, warm_up, repetitions):
    """Time forward passes of ``model`` on its CUDA device using CUDA events.

    Inputs are moved to the device once. Then ``warm_up`` untimed passes run,
    followed by ``repetitions`` timed passes, all under ``torch.no_grad()``.

    Args:
        model: module to time. Its device is ``model.device`` if present,
            else its first parameter's device, else the current CUDA device.
        dummy_inputs: a mapping passed as keyword arguments, or a sequence passed
            positionally. Tensors are moved to the device; other values are passed as-is.
        warm_up: number of untimed forward passes.
        repetitions: number of timed forward passes (must be >= 1).

    Returns:
        ``(mean, std)`` of the per-pass time in milliseconds. ``std`` comes from
        ``torch.Tensor.std`` (sample std; NaN when ``repetitions == 1``).

    Raises:
        ValueError: if ``repetitions < 1``.
    """
    if repetitions < 1:
        raise ValueError("repetitions must be at least 1")
    device = getattr(model, 'device', None)
    if not isinstance(device, torch.device):
        # A plain nn.Module has no ``device`` attribute; infer it from the parameters.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    def to_device(x):
        return x.to(device) if isinstance(x, torch.Tensor) else x

    # The inputs never change between passes, so transfer them once.
    if isinstance(dummy_inputs, Mapping):
        kwargs = {k: to_device(v) for k, v in dummy_inputs.items()}

        def forward():
            return model(**kwargs)
    else:
        args = [to_device(t) for t in dummy_inputs]

        def forward():
            return model(*args)

    timings = torch.zeros(repetitions)
    # Events record on the current device's stream. Make the model's device
    # current so the timing brackets its kernels on multi-GPU machines.
    with torch.no_grad(), torch.cuda.device(device):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for _ in tqdm(range(warm_up), desc='cuda-warm-up'):
            forward()
        for rep in tqdm(range(repetitions), desc='cuda-repetitions'):
            starter.record()
            forward()
            ender.record()
            # elapsed_time is only valid once both events have completed.
            torch.cuda.synchronize()
            timings[rep] = starter.elapsed_time(ender)
    return timings.mean().item(), timings.std().item()
def cpu_inference_time(model: torch.nn.Module, dummy_inputs, warm_up, repetitions):
    """Time forward passes of ``model`` with a monotonic wall-clock timer.

    Inputs are moved to the device once. Then ``warm_up`` untimed passes run,
    followed by ``repetitions`` timed passes, all under ``torch.no_grad()``.

    Args:
        model: module to time. Its device is ``model.device`` if present,
            else its first parameter's device, else CPU.
        dummy_inputs: a mapping passed as keyword arguments, or a sequence passed
            positionally. Tensors are moved to the device; other values are passed as-is.
        warm_up: number of untimed forward passes.
        repetitions: number of timed forward passes (must be >= 1).

    Returns:
        ``(mean, std)`` of the per-pass time in milliseconds. ``std`` comes from
        ``torch.Tensor.std`` (sample std; NaN when ``repetitions == 1``).

    Raises:
        ValueError: if ``repetitions < 1``.
    """
    if repetitions < 1:
        raise ValueError("repetitions must be at least 1")
    device = getattr(model, 'device', None)
    if not isinstance(device, torch.device):
        # A plain nn.Module has no ``device`` attribute; infer it from the parameters.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    def to_device(x):
        return x.to(device) if isinstance(x, torch.Tensor) else x

    # The inputs never change between passes, so move them once.
    if isinstance(dummy_inputs, Mapping):
        kwargs = {k: to_device(v) for k, v in dummy_inputs.items()}

        def forward():
            return model(**kwargs)
    else:
        args = [to_device(t) for t in dummy_inputs]

        def forward():
            return model(*args)

    timings = torch.zeros(repetitions)
    with torch.no_grad():
        for _ in tqdm(range(warm_up), desc='cpu-warm-up'):
            forward()
        for rep in tqdm(range(repetitions), desc='cpu-repetitions'):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            forward()
            timings[rep] = (time.perf_counter() - start) * 1000
    return timings.mean().item(), timings.std().item()