import torch
import memtorch
from memtorch.map.Input import naive_scale
from memtorch.map.Parameter import naive_map
from .Conv1d import Conv1d
from .Conv2d import Conv2d
from .Conv3d import Conv3d
from .RNN import RNN
from .Linear import Linear
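
# Maps str(type(module)) for each supported torch.nn layer to its memristive
# replacement; patch_module() below performs lookups against these keys.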
supported_module_parameters = {
"<class 'torch.nn.modules.linear.Linear'>": Linear,
"<class 'torch.nn.modules.conv.Conv1d'>": Conv1d,
"<class 'torch.nn.modules.conv.Conv2d'>": Conv2d,
"<class 'torch.nn.modules.conv.Conv3d'>": Conv3d,
"<class 'torch.nn.modules.rnn.RNN'>": RNN,
}
def patch_model(
model,
memristor_model,
memristor_model_params,
module_parameters_to_patch={},
mapping_routine=naive_map,
transistor=True,
programming_routine=None,
programming_routine_params={"rel_tol": 0.1},
p_l=None,
scheme=memtorch.bh.Scheme.DoubleColumn,
tile_shape=None,
max_input_voltage=None,
scaling_routine=naive_scale,
scaling_routine_params={},
source_resistance=None,
line_resistance=None,
ADC_resolution=None,
ADC_overflow_rate=0.0,
quant_method=None,
use_bindings=True,
random_crossbar_init=False,
verbose=True,
**kwargs
):
"""Method to convert a torch.nn model to a memristive model.
Parameters
----------
model : torch.nn.Module
torch.nn.Module to patch.
memristor_model : memtorch.bh.memristor.Memristor.Memristor
Memristor model.
memristor_model_params : **kwargs
Memristor model keyword arguments.
    module_parameters_to_patch : list of type
        torch.nn module types to patch, e.g. [torch.nn.Linear, torch.nn.Conv2d].
mapping_routine : function
Mapping routine to use.
transistor : bool
Used to determine if a 1T1R (True) or 1R arrangement (False) is simulated.
programming_routine : function
Programming routine to use.
programming_routine_params : **kwargs
Programming routine keyword arguments.
    p_l : float
If not None, the proportion of weights to retain.
scheme : memtorch.bh.Scheme
Weight representation scheme.
tile_shape : (int, int)
Tile shape to use to store weights. If None, modular tiles are not used.
max_input_voltage : float
Maximum input voltage used to encode inputs. If None, inputs are unbounded.
scaling_routine : function
Scaling routine to use in order to scale batch inputs.
scaling_routine_params : **kwargs
Scaling routine keyword arguments.
source_resistance : float
The resistance between word/bit line voltage sources and crossbar(s).
line_resistance : float
The interconnect line resistance between adjacent cells.
ADC_resolution : int
ADC resolution (bit width). If None, quantization noise is not accounted for.
ADC_overflow_rate : float
        Overflow rate threshold for linear quantization (if ADC_resolution is not None).
    quant_method : str
        Quantization method. Must be in ['linear', 'log', 'log_minmax', 'minmax', 'tanh'], or None.
use_bindings : bool
Used to determine if C++/CUDA bindings are used (True) or not (False).
random_crossbar_init : bool
        Determines if the crossbar is to be initialized with random values between R_on and R_off.
verbose : bool
Used to determine if verbose output is enabled (True) or disabled (False).

    Returns
    -------
torch.nn.Module
Patched torch.nn.Module.
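
    Examples
    --------
    A minimal usage sketch for a pre-trained ``model``; the memristor model
    and parameter values shown here are illustrative:

    >>> import copy
    >>> import torch
    >>> import memtorch
    >>> patched_model = patch_model(
    ...     copy.deepcopy(model),
    ...     memristor_model=memtorch.bh.memristor.VTEAM,
    ...     memristor_model_params={"time_series_resolution": 1e-10},
    ...     module_parameters_to_patch=[torch.nn.Linear],
    ...     mapping_routine=naive_map,
    ...     transistor=True,
    ...     scheme=memtorch.bh.Scheme.DoubleColumn,
    ... )
    >>> patched_model.tune_()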
"""
    def patch_module(target_attr):
        # Resolve the memristive replacement class for this module type and
        # fail loudly if the concrete type has no registered equivalent.
        parameter_type = str(type(target_attr))
        patch = supported_module_parameters.get(parameter_type)
        assert patch is not None, "unsupported module type: %s" % parameter_type
return patch(
target_attr,
memristor_model=memristor_model,
memristor_model_params=memristor_model_params,
mapping_routine=mapping_routine,
transistor=transistor,
programming_routine=programming_routine,
programming_routine_params=programming_routine_params,
p_l=p_l,
scheme=scheme,
tile_shape=tile_shape,
max_input_voltage=max_input_voltage,
scaling_routine=scaling_routine,
scaling_routine_params=scaling_routine_params,
source_resistance=source_resistance,
line_resistance=line_resistance,
ADC_resolution=ADC_resolution,
ADC_overflow_rate=ADC_overflow_rate,
quant_method=quant_method,
use_bindings=use_bindings,
random_crossbar_init=random_crossbar_init,
verbose=verbose,
**kwargs
)
    def patch_modules(module, name=""):
        """Recursively replace all supported and unpatched torch.nn modules
        with their memristive equivalents. Previously patched layers are
        identified by the presence of a `transistor` attribute."""
        for attr_str in dir(module):
            target_attr = getattr(module, attr_str)
            if any(
                isinstance(target_attr, module_parameter)
                and not hasattr(target_attr, "transistor")
                for module_parameter in module_parameters_to_patch
            ):
                setattr(module, attr_str, patch_module(target_attr))
        if isinstance(module, torch.nn.Module):
            if type(module) is torch.nn.Sequential:
                # Children of a Sequential container are replaced by index.
                for idx, (name, child) in enumerate(module.named_children()):
                    if any(
                        isinstance(child, module_parameter)
                        and not hasattr(child, "transistor")
                        for module_parameter in module_parameters_to_patch
                    ):
                        module[idx] = patch_module(module[idx])
                    else:
                        patch_modules(child, name)
            else:
                for name, child in module.named_children():
                    patch_modules(child, name)
        else:
            # Fall back to iterating plain containers (e.g., a list of modules).
            for child in module:
                patch_modules(child, name)
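
    # Patch the supplied model in place; the helpers defined below are bound
    # to the returned model as methods.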
patch_modules(model)
    def tune_(self, tune_kwargs=None):
        """Method to tune all patched memristive layers of a model.

        Parameters
        ----------
        tune_kwargs : dict
            Dictionary of **kwargs for different layer types for .tune().
        """
        for _, m in self.named_modules():
            if hasattr(m, "tune"):
                module_type = str(type(m))
                if tune_kwargs is not None and module_type in tune_kwargs:
                    m.tune(**tune_kwargs[module_type])
                else:
                    m.tune()
    def forward_legacy(self, enable_forward_legacy):
        """Method to enable or disable forward legacy operation.

        Parameters
        ----------
        enable_forward_legacy : bool
            Enable or disable forward legacy operation.
        """
        for _, m in self.named_modules():
            if type(m) in supported_module_parameters.values():
                m.forward_legacy_enabled = enable_forward_legacy
    def disable_legacy(self):
        """Method to delete all legacy parameters to reduce memory usage.
        When this method is called, forward legacy operation is disabled."""
        for _, m in self.named_modules():
            if type(m) in supported_module_parameters.values():
                if type(m) is RNN:
                    # RNN layers hold separate input-hidden and hidden-hidden
                    # legacy weights (plus reversed copies if bidirectional).
                    delattr(m, "w_ih")
                    m.w_ih = None
                    delattr(m, "w_hh")
                    m.w_hh = None
                    if m.bidirectional:
                        delattr(m, "w_ih_reverse")
                        m.w_ih_reverse = None
                        delattr(m, "w_hh_reverse")
                        m.w_hh_reverse = None
                else:
                    delattr(m, "weight")
                    m.weight = None
        # The CPU-only distribution of MemTorch tags its version string with
        # "cpu"; only free cached CUDA memory when using the GPU build.
        if "cpu" not in memtorch.__version__:
            torch.cuda.empty_cache()
        self.forward_legacy(False)
        delattr(self, "forward_legacy")
    def set_cuda_malloc_heap_size(self, cuda_malloc_heap_size):
        """Method to set the CUDA malloc heap size.

        Parameters
        ----------
        cuda_malloc_heap_size : int
            CUDA malloc heap size used by the C++/CUDA bindings.
        """
        for _, m in self.named_modules():
            if type(m) in supported_module_parameters.values():
                m.cuda_malloc_heap_size = cuda_malloc_heap_size
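
    # Bind the helper methods to the patched model instance and disable
    # forward legacy operation by default.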
model.forward_legacy = forward_legacy.__get__(model)
model.tune_ = tune_.__get__(model)
model.forward_legacy(False)
model.disable_legacy = disable_legacy.__get__(model)
model.set_cuda_malloc_heap_size = set_cuda_malloc_heap_size.__get__(model)
return model