import math
import warnings
import numpy as np
import torch
import torch.nn as nn
import memtorch
from memtorch.bh.crossbar.Crossbar import init_crossbar, simulate_matmul
from memtorch.bh.crossbar.Tile import tiled_inference
from memtorch.map.Input import naive_scale
from memtorch.map.Module import naive_tune
from memtorch.map.Parameter import naive_map
class Conv3d(nn.Conv3d):
"""nn.Conv3d equivalent.
Parameters
----------
convolutional_layer : torch.nn.Conv3d
Convolutional layer to patch.
memristor_model : memtorch.bh.memristor.Memristor.Memristor
Memristor model.
memristor_model_params : **kwargs
Memristor model keyword arguments.
mapping_routine : function
Mapping routine to use.
transistor : bool
Used to determine if a 1T1R (True) or 1R arrangement (False) is simulated.
programming_routine : function
Programming routine to use.
programming_routine_params : **kwargs
Programming routine keyword arguments.
    p_l : float
If not None, the proportion of weights to retain.
scheme : memtorch.bh.Scheme
Weight representation scheme.
tile_shape : (int, int)
Tile shape to use to store weights. If None, modular tiles are not used.
max_input_voltage : float
Maximum input voltage used to encode inputs. If None, inputs are unbounded.
scaling_routine : function
Scaling routine to use in order to scale batch inputs.
scaling_routine_params : **kwargs
Scaling routine keyword arguments.
source_resistance : float
The resistance between word/bit line voltage sources and crossbar(s).
line_resistance : float
The interconnect line resistance between adjacent cells.
ADC_resolution : int
ADC resolution (bit width). If None, quantization noise is not accounted for.
ADC_overflow_rate : float
        Overflow rate threshold for linear quantization (if ADC_resolution is not None).
    quant_method : str
Quantization method. Must be in ['linear', 'log', 'log_minmax', 'minmax', 'tanh'], or None.
use_bindings : bool
Used to determine if C++/CUDA bindings are used (True) or not (False).
    random_crossbar_init : bool
        Determines if the crossbar is to be initialized with random conductance values between Ron and Roff.
verbose : bool
Used to determine if verbose output is enabled (True) or disabled (False).
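
    Examples
    --------
    A minimal usage sketch (illustrative only; it assumes the
    memtorch.bh.memristor.VTEAM reference model and arbitrary model
    parameters)::

        import copy
        conv = torch.nn.Conv3d(in_channels=3, out_channels=8, kernel_size=3)
        patched_conv = Conv3d(
            copy.deepcopy(conv),
            memristor_model=memtorch.bh.memristor.VTEAM,
            memristor_model_params={"time_series_resolution": 1e-10},
        )
        patched_conv.forward_legacy_enabled = False  # route through the crossbar simulation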
"""
def __init__(
self,
convolutional_layer,
memristor_model,
memristor_model_params,
mapping_routine=naive_map,
transistor=True,
programming_routine=None,
programming_routine_params={},
p_l=None,
scheme=memtorch.bh.Scheme.DoubleColumn,
tile_shape=None,
max_input_voltage=None,
scaling_routine=naive_scale,
scaling_routine_params={},
source_resistance=None,
line_resistance=None,
ADC_resolution=None,
ADC_overflow_rate=0.0,
quant_method=None,
use_bindings=True,
random_crossbar_init=False,
verbose=True,
*args,
**kwargs
):
assert isinstance(
convolutional_layer, nn.Conv3d
), "convolutional_layer is not an instance of nn.Conv3d."
        assert (
            convolutional_layer.groups == 1
        ), "groups != 1 is not currently supported for convolutional layers."
self.device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")
self.transistor = transistor
self.scheme = scheme
self.tile_shape = tile_shape
self.max_input_voltage = max_input_voltage
self.scaling_routine = scaling_routine
self.scaling_routine_params = scaling_routine_params
self.source_resistance = source_resistance
self.line_resistance = line_resistance
self.ADC_resolution = ADC_resolution
self.ADC_overflow_rate = ADC_overflow_rate
if "cpu" not in memtorch.__version__:
self.cuda_malloc_heap_size = 50
else:
self.cuda_malloc_heap_size = None
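        # Note: cuda_malloc_heap_size is only consumed by the C++/CUDA bindings,
        # so it is left as None for CPU-only builds.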
if not transistor:
assert (
source_resistance is not None and source_resistance >= 0.0
), "Source resistance is invalid."
assert (
line_resistance is not None and line_resistance >= 0.0
), "Line resistance is invalid."
if quant_method in memtorch.bh.Quantize.quant_methods:
self.quant_method = quant_method
else:
self.quant_method = None
if quant_method is not None:
assert (
ADC_resolution is not None
and type(ADC_resolution) == int
and ADC_resolution > 0
), "ADC resolution is invalid."
assert (
ADC_overflow_rate is not None
), "ADC_overflow_rate must be specified if quant_method is not None."
self.use_bindings = use_bindings
self.verbose = verbose
self.forward_legacy_enabled = True
super(Conv3d, self).__init__(
convolutional_layer.in_channels,
convolutional_layer.out_channels,
convolutional_layer.kernel_size,
stride=convolutional_layer.stride,
padding=convolutional_layer.padding,
dilation=convolutional_layer.dilation,
groups=convolutional_layer.groups,
**kwargs
)
self.weight.data = convolutional_layer.weight.data
if convolutional_layer.bias is not None:
self.bias.data = convolutional_layer.bias.data
self.zero_grad()
self.weight.requires_grad = False
if convolutional_layer.bias is not None:
self.bias.requires_grad = False
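        # Map the frozen weights onto simulated memristive crossbars. Under the
        # default memtorch.bh.Scheme.DoubleColumn scheme, each weight is encoded
        # as the difference between two device conductances.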
self.crossbars, self.crossbar_operation = init_crossbar(
weights=self.weight,
memristor_model=memristor_model,
memristor_model_params=memristor_model_params,
transistor=transistor,
mapping_routine=mapping_routine,
programming_routine=programming_routine,
programming_routine_params=programming_routine_params,
p_l=p_l,
scheme=scheme,
tile_shape=tile_shape,
use_bindings=use_bindings,
cuda_malloc_heap_size=self.cuda_malloc_heap_size,
random_crossbar_init=random_crossbar_init,
)
self.transform_output = lambda x: x
if verbose:
print("Patched %s -> %s" % (convolutional_layer, self))
    def forward(self, input):
"""Method to perform forward propagations.
Parameters
----------
input : torch.Tensor
Input tensor.
Returns
-------
torch.Tensor
Output tensor.
"""
if self.forward_legacy_enabled:
return torch.nn.functional.conv3d(
input.to(self.device),
self.weight.to(self.device),
bias=self.bias,
stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
            )
else:
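            # Standard convolution output size (dilation is assumed to be 1 here):
            # D_out = floor((D_in - K + 2 * P) / S) + 1, per spatial dimension.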
output_dim = [0, 0, 0]
output_dim[0] = (
int(
(input.shape[2] - self.kernel_size[0] + 2 * self.padding[0])
/ self.stride[0]
)
+ 1
)
output_dim[1] = (
int(
(input.shape[3] - self.kernel_size[1] + 2 * self.padding[1])
/ self.stride[1]
)
+ 1
)
output_dim[2] = (
int(
(input.shape[4] - self.kernel_size[2] + 2 * self.padding[2])
/ self.stride[2]
)
+ 1
)
out = torch.zeros(
(
input.shape[0],
self.out_channels,
output_dim[0],
output_dim[1],
output_dim[2],
)
).to(self.device)
for batch in range(input.shape[0]):
if not all(item == 0 for item in self.padding):
batch_input = nn.functional.pad(
input[batch],
pad=(
self.padding[2],
self.padding[2],
self.padding[1],
self.padding[1],
self.padding[0],
self.padding[0],
),
)
else:
batch_input = input[batch]
batch_input = self.scaling_routine(
self, batch_input, **self.scaling_routine_params
)
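                # im2col-style lowering: unfold the (padded, scaled) input along
                # each spatial dimension and flatten every receptive field into a
                # row, reducing the convolution to a single matrix multiplication.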
unfolded_batch_input = (
batch_input.unfold(1, self.kernel_size[0], self.stride[0])
.unfold(2, self.kernel_size[1], self.stride[1])
.unfold(3, self.kernel_size[2], self.stride[2])
.permute(1, 2, 3, 0, 4, 5, 6)
.reshape(
-1,
(self.in_channels // self.groups)
* self.kernel_size[0]
* self.kernel_size[1]
* self.kernel_size[2],
)
)
if hasattr(self, "non_linear"):
warnings.warn(
"Non-liner modeling does not currently account for source and line resistances."
)
if self.tile_shape is not None:
tiles_map = self.crossbars[0].tiles_map
crossbar_shape = (
self.crossbars[0].rows,
self.crossbars[0].columns,
)
else:
tiles_map = None
crossbar_shape = None
if hasattr(self, "simulate"):
nl = False
else:
nl = True
out_ = (
self.crossbar_operation(
self.crossbars,
lambda crossbar, input_: simulate_matmul(
                            input_,
crossbar,
nl=nl,
tiles_map=tiles_map,
crossbar_shape=crossbar_shape,
max_input_voltage=self.max_input_voltage,
ADC_resolution=self.ADC_resolution,
ADC_overflow_rate=self.ADC_overflow_rate,
quant_method=self.quant_method,
use_bindings=self.use_bindings,
),
input_=unfolded_batch_input,
)
.to(self.device)
.T
)
else:
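                    # Linear device model: perform the multiplication directly,
                    # tile-by-tile when tiling is enabled, or against the full
                    # conductance matrix otherwise.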
if self.tile_shape is not None:
out_ = tiled_inference(
unfolded_batch_input, self, transistor=self.transistor
).T
else:
devices = self.crossbar_operation(
self.crossbars,
lambda crossbar: crossbar.conductance_matrix,
)
if self.transistor:
out_ = torch.matmul(
unfolded_batch_input,
devices,
).T
else:
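                            # 1R (passive) arrangement: solve the resistive crossbar
                            # network directly so that source and line resistances
                            # (IR drop) are accounted for in the readout.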
out_ = memtorch.bh.crossbar.Passive.solve_passive(
devices,
unfolded_batch_input.to(self.device),
torch.zeros(
unfolded_batch_input.shape[0], devices.shape[1]
),
self.source_resistance,
self.line_resistance,
n_input_batches=unfolded_batch_input.shape[0],
use_bindings=self.use_bindings,
cuda_malloc_heap_size=self.cuda_malloc_heap_size,
).T
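                # Model a finite-resolution ADC by quantizing the raw crossbar
                # outputs using the configured quantization method.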
if self.quant_method is not None:
out_ = memtorch.bh.Quantize.quantize(
out_,
quant=self.ADC_resolution,
overflow_rate=self.ADC_overflow_rate,
quant_method=self.quant_method,
)
                out[batch] = out_.view(size=(1, self.out_channels, *output_dim))
            out = self.transform_output(out)
            if self.bias is not None:
                out += self.bias.data.view(-1, 1, 1, 1).expand_as(out)
            return out
    def tune(self, input_batch_size=4, input_shape=32):
"""Tuning method."""
self.transform_output = naive_tune(
self,
(
input_batch_size,
(self.in_channels // self.groups),
input_shape,
input_shape,
input_shape,
),
self.verbose,
)
def __str__(self):
return (
"bh.Conv3d(in_channels=%d, out_channels=%d, kernel_size=(%d, %d, %d), stride=(%d, %d, %d), padding=(%d, %d, %d))"
% (
self.in_channels,
self.out_channels,
self.kernel_size[0],
self.kernel_size[1],
self.kernel_size[2],
self.stride[0],
self.stride[1],
self.stride[2],
self.padding[0],
self.padding[1],
self.padding[2],
)
)