Source code for memtorch.mn.Linear

import math
import warnings

import numpy as np
import torch
import torch.nn as nn

import memtorch
from memtorch.bh.crossbar.Crossbar import init_crossbar, simulate_matmul
from memtorch.bh.crossbar.Tile import tiled_inference
from memtorch.map.Input import naive_scale
from memtorch.map.Module import naive_tune
from memtorch.map.Parameter import naive_map


class Linear(nn.Linear):
    """nn.Linear equivalent.

    Parameters
    ----------
    linear_layer : torch.nn.Linear
        Linear layer to patch.
    memristor_model : memtorch.bh.memristor.Memristor.Memristor
        Memristor model.
    memristor_model_params : **kwargs
        Memristor model keyword arguments.
    mapping_routine : function
        Mapping routine to use.
    transistor : bool
        Used to determine if a 1T1R (True) or 1R arrangement (False) is simulated.
    programming_routine : function
        Programming routine to use.
    programming_routine_params : **kwargs
        Programming routine keyword arguments.
    p_l : float
        If not None, the proportion of weights to retain.
    scheme : memtorch.bh.Scheme
        Weight representation scheme.
    tile_shape : (int, int)
        Tile shape to use to store weights. If None, modular tiles are not used.
    max_input_voltage : float
        Maximum input voltage used to encode inputs. If None, inputs are unbounded.
    scaling_routine : function
        Scaling routine to use in order to scale batch inputs.
    scaling_routine_params : **kwargs
        Scaling routine keyword arguments.
    source_resistance : float
        The resistance between word/bit line voltage sources and crossbar(s).
    line_resistance : float
        The interconnect line resistance between adjacent cells.
    ADC_resolution : int
        ADC resolution (bit width). If None, quantization noise is not accounted for.
    ADC_overflow_rate : float
        Overflow rate threshold for linear quantization (if ADC_resolution is not None).
    quant_method : string
        Quantization method. Must be in ['linear', 'log', 'log_minmax', 'minmax', 'tanh'], or None.
    use_bindings : bool
        Used to determine if C++/CUDA bindings are used (True) or not (False).
    random_crossbar_init : bool
        Determines if the crossbar is to be initialized at random values in between Ron and Roff.
    verbose : bool
        Used to determine if verbose output is enabled (True) or disabled (False).
    """

    def __init__(
        self,
        linear_layer,
        memristor_model,
        memristor_model_params,
        mapping_routine=naive_map,
        transistor=True,
        programming_routine=None,
        programming_routine_params={},
        p_l=None,
        scheme=memtorch.bh.Scheme.DoubleColumn,
        tile_shape=None,
        max_input_voltage=None,
        scaling_routine=naive_scale,
        scaling_routine_params={},
        source_resistance=None,
        line_resistance=None,
        ADC_resolution=None,
        ADC_overflow_rate=0.0,
        quant_method=None,
        use_bindings=True,
        random_crossbar_init=False,
        verbose=True,
        *args,
        **kwargs
    ):
        assert isinstance(
            linear_layer, nn.Linear
        ), "linear_layer is not an instance of nn.Linear."
        self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")
        self.transistor = transistor
        self.scheme = scheme
        self.tile_shape = tile_shape
        self.max_input_voltage = max_input_voltage
        self.scaling_routine = scaling_routine
        self.scaling_routine_params = scaling_routine_params
        self.source_resistance = source_resistance
        self.line_resistance = line_resistance
        self.ADC_resolution = ADC_resolution
        self.ADC_overflow_rate = ADC_overflow_rate
        if "cpu" not in memtorch.__version__:
            self.cuda_malloc_heap_size = 50
        else:
            self.cuda_malloc_heap_size = None

        if not transistor:
            assert (
                source_resistance is not None and source_resistance >= 0.0
            ), "Source resistance is invalid."
            assert (
                line_resistance is not None and line_resistance >= 0.0
            ), "Line resistance is invalid."

        if quant_method in memtorch.bh.Quantize.quant_methods:
            self.quant_method = quant_method
        else:
            self.quant_method = None

        if quant_method is not None:
            assert (
                ADC_resolution is not None
                and type(ADC_resolution) == int
                and ADC_resolution > 0
            ), "ADC resolution is invalid."
            assert (
                ADC_overflow_rate is not None
            ), "ADC_overflow_rate must be specified if quant_method is not None."

        self.use_bindings = use_bindings
        self.verbose = verbose
        self.forward_legacy_enabled = True
        super(Linear, self).__init__(
            linear_layer.in_features, linear_layer.out_features, **kwargs
        )
        self.weight.data = linear_layer.weight.data
        if linear_layer.bias is not None:
            self.bias.data = linear_layer.bias.data
        else:
            self.bias = None

        self.zero_grad()
        self.weight.requires_grad = False
        if linear_layer.bias is not None:
            self.bias.requires_grad = False

        self.crossbars, self.crossbar_operation = init_crossbar(
            weights=self.weight,
            memristor_model=memristor_model,
            memristor_model_params=memristor_model_params,
            transistor=transistor,
            mapping_routine=mapping_routine,
            programming_routine=programming_routine,
            programming_routine_params=programming_routine_params,
            p_l=p_l,
            scheme=scheme,
            tile_shape=tile_shape,
            use_bindings=use_bindings,
            cuda_malloc_heap_size=self.cuda_malloc_heap_size,
            random_crossbar_init=random_crossbar_init,
        )
        self.transform_output = lambda x: x
        if verbose:
            print("Patched %s -> %s" % (linear_layer, self))
    def forward(self, input):
        """Method to perform forward propagations.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        if self.forward_legacy_enabled:
            out = torch.matmul(
                input.to(self.device), self.weight.data.T.to(self.device)
            )
            if self.bias is not None:
                out += self.bias.view(1, -1).expand_as(out)

            return out
        else:
            input = self.scaling_routine(self, input, **self.scaling_routine_params)
            if hasattr(self, "non_linear"):
                warnings.warn(
                    "Non-linear modeling does not currently account for source and line resistances."
                )
                if self.tile_shape is not None:
                    tiles_map = self.crossbars[0].tiles_map
                    crossbar_shape = self.weight.data.shape
                else:
                    tiles_map = None
                    crossbar_shape = None

                if hasattr(self, "simulate"):
                    nl = False
                else:
                    nl = True

                out_ = self.crossbar_operation(
                    self.crossbars,
                    lambda crossbar, input_: simulate_matmul(
                        input,
                        crossbar,
                        nl=nl,
                        tiles_map=tiles_map,
                        crossbar_shape=crossbar_shape,
                        max_input_voltage=self.max_input_voltage,
                        ADC_resolution=self.ADC_resolution,
                        ADC_overflow_rate=self.ADC_overflow_rate,
                        quant_method=self.quant_method,
                        use_bindings=self.use_bindings,
                    ),
                    input_=input,
                ).to(self.device)
            else:
                if self.tile_shape is not None:
                    out_ = tiled_inference(input, self, transistor=self.transistor)
                else:
                    devices = self.crossbar_operation(
                        self.crossbars, lambda crossbar: crossbar.conductance_matrix
                    )
                    if self.transistor:
                        out_ = torch.matmul(
                            input.to(self.device),
                            devices,
                        )
                    else:
                        out_ = memtorch.bh.crossbar.Passive.solve_passive(
                            devices,
                            input.to(self.device),
                            torch.zeros(input.shape[0], devices.shape[1]),
                            self.source_resistance,
                            self.line_resistance,
                            n_input_batches=input.shape[0],
                            use_bindings=self.use_bindings,
                            cuda_malloc_heap_size=self.cuda_malloc_heap_size,
                        )

                    if self.quant_method is not None:
                        out_ = memtorch.bh.Quantize.quantize(
                            out_,
                            quant=self.ADC_resolution,
                            overflow_rate=self.ADC_overflow_rate,
                            quant_method=self.quant_method,
                        )

            out = self.transform_output(out_).to(self.device)
            if self.bias is not None:
                out += self.bias.data.view(1, -1).to(self.device).expand_as(out)

        return out
    def tune(self, input_shape=4098):
        """Tuning method."""
        self.transform_output = naive_tune(
            self, (input_shape, self.in_features), self.verbose
        )
    def __str__(self):
        return "bh.Linear(in_features=%d, out_features=%d, bias=%s)" % (
            self.in_features,
            self.out_features,
            self.bias is not None,
        )
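
A minimal usage sketch (not part of the module above), showing how an existing torch.nn.Linear layer might be patched directly with this class. The reference memristor model (memtorch.bh.memristor.VTEAM) and its time_series_resolution parameter are assumptions based on the rest of the memtorch package; the layer sizes, tile shape, input voltage, and ADC settings are illustrative only.

    # Hedged sketch: constructing memtorch.mn.Linear from a pre-trained nn.Linear.
    import torch
    import memtorch
    import memtorch.bh.memristor  # VTEAM reference model is assumed to live here
    from memtorch.map.Parameter import naive_map
    from memtorch.mn.Linear import Linear  # the class defined above

    fc = torch.nn.Linear(in_features=128, out_features=10)  # illustrative layer

    reference_memristor = memtorch.bh.memristor.VTEAM  # assumed device model
    reference_memristor_params = {"time_series_resolution": 1e-10}  # illustrative value

    patched_fc = Linear(
        linear_layer=fc,
        memristor_model=reference_memristor,
        memristor_model_params=reference_memristor_params,
        mapping_routine=naive_map,
        transistor=True,
        programming_routine=None,
        tile_shape=(128, 128),
        max_input_voltage=0.3,
        ADC_resolution=8,
        ADC_overflow_rate=0.0,
        quant_method="linear",
    )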
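forward() dispatches on forward_legacy_enabled: when True it performs the ideal torch.matmul against the stored weights; when False the inputs are scaled by scaling_routine, driven through the simulated crossbar(s) (tiled, passive, or non-linear, depending on configuration), optionally quantized to the ADC resolution, and rescaled by transform_output. The sketch below continues the hypothetical patched_fc from the previous example; toggling the flag directly on a single patched layer is an assumption about how one might exercise both paths (when patching whole models, the helpers in memtorch.mn.Module are normally used instead).

    # Hedged sketch: comparing the ideal and memristive forward paths.
    x = torch.rand(16, 128)  # batch of 16 inputs for the hypothetical 128-feature layer

    patched_fc.forward_legacy_enabled = True
    out_ideal = patched_fc(x)        # ideal torch.matmul path

    patched_fc.forward_legacy_enabled = False
    out_memristive = patched_fc(x)   # scaled inputs, crossbar simulation, ADC quantization

    print((out_ideal - out_memristive).abs().mean())  # rough measure of simulation error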
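tune() replaces the identity transform_output with a correction fitted by naive_tune over a randomly generated batch of shape (input_shape, self.in_features); the exact fitting procedure is defined in memtorch.map.Module.naive_tune. A sketch, again continuing the hypothetical patched_fc:

    # Hedged sketch: tuning the patched layer, then re-running the memristive path.
    patched_fc.tune()                      # default input_shape=4098 random samples
    patched_fc.forward_legacy_enabled = False
    out_tuned = patched_fc(x)              # outputs rescaled by the fitted transform_output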