import math
from enum import Enum, auto, unique
import numpy as np
import torch
import memtorch
import memtorch.mn
from memtorch.bh.nonideality.DeviceFaults import apply_device_faults
from memtorch.bh.nonideality.Endurance import apply_endurance_model
from memtorch.bh.nonideality.FiniteConductanceStates import (
apply_finite_conductance_states,
)
from memtorch.bh.nonideality.NonLinear import apply_non_linear
from memtorch.bh.nonideality.Retention import apply_retention_model
from memtorch.mn.Module import supported_module_parameters
@unique
class NonIdeality(Enum):
    """NonIdeality enumeration.

    Each member selects one device non-ideality that ``apply_nonidealities``
    can apply to the memristive layers of a model. Values are assigned with
    ``auto()``, and ``@unique`` guarantees no two members share a value.
    """

    # Applied with apply_finite_conductance_states.
    FiniteConductanceStates = auto()
    # Applied with apply_device_faults.
    DeviceFaults = auto()
    # Applied with apply_non_linear.
    NonLinear = auto()
    # Applied with apply_endurance_model.
    Endurance = auto()
    # Applied with apply_retention_model.
    Retention = auto()
def _replace_submodule(model, name, patched_module):
    """Install ``patched_module`` at the dotted path ``name`` inside ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        Root module that ``name`` is relative to.
    name : str
        Module name as produced by ``model.named_modules()``; may contain
        dots for nested children and is ``""`` for the root itself.
    patched_module : torch.nn.Module
        Module to store at that location.

    Returns
    -------
    torch.nn.Module
        The root module after replacement (``patched_module`` itself when
        ``name`` refers to the root).
    """
    if not name:
        # The root has no parent to register it in; registering it under the
        # "" key would make the model its own child.
        return patched_module
    parent_path, _, leaf = name.rpartition(".")
    parent = model
    if parent_path:
        # Walk down to the direct parent so the nested child is replaced,
        # rather than adding a top-level entry keyed by the dotted name.
        for part in parent_path.split("."):
            parent = parent._modules[part]
    parent._modules[leaf] = patched_module
    return model


def apply_nonidealities(model, non_idealities, **kwargs):
    """Method to apply non-idealities to a torch.nn.Module instance with memristive layers.

    Every module of ``model`` whose exact type is one of the values of
    ``supported_module_parameters`` is passed, in turn, through the routine
    associated with each requested non-ideality and the result is written
    back into the model. Entries of ``non_idealities`` that are not handled
    here are ignored.

    Parameters
    ----------
    model : torch.nn.Module
        torch.nn.Module instance.
    non_idealities : memtorch.bh.nonideality.NonIdeality.NonIdeality or iterable of it
        Non-idealities to model. A single member is treated as a
        one-element sequence.
    **kwargs
        Arguments forwarded to the individual non-ideality routines:

        * ``FiniteConductanceStates``: ``conductance_states``.
        * ``DeviceFaults``: ``lrs_proportion``, ``hrs_proportion``,
          ``electroform_proportion``.
        * ``NonLinear``: either ``simulate=True``, or ``sweep_duration``,
          ``sweep_voltage_signal_amplitude`` and
          ``sweep_voltage_signal_frequency``.
        * ``Endurance``: ``x``, ``endurance_model``,
          ``endurance_model_kwargs`` (a dict expanded into keyword
          arguments).
        * ``Retention``: ``time``, ``retention_model``,
          ``retention_model_kwargs`` (a dict expanded into keyword
          arguments).

        Required arguments are validated with ``required`` before use.

    Returns
    -------
    torch.nn.Module
        Patched instance.
    """
    if isinstance(non_idealities, NonIdeality):
        non_idealities = (non_idealities,)
    else:
        # Materialise once: the sequence is iterated again for every layer.
        non_idealities = tuple(non_idealities)

    supported_types = tuple(supported_module_parameters.values())
    # Snapshot the module list because the model is modified while looping.
    for name, layer in list(model.named_modules()):
        if type(layer) not in supported_types:
            continue
        for non_ideality in non_idealities:
            # ``layer`` is rebound to each result so that successive
            # non-idealities build on one another.
            if non_ideality == NonIdeality.FiniteConductanceStates:
                required(
                    kwargs,
                    ["conductance_states"],
                    "memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates",
                )
                layer = apply_finite_conductance_states(
                    layer, kwargs["conductance_states"]
                )
            elif non_ideality == NonIdeality.DeviceFaults:
                required(
                    kwargs,
                    ["lrs_proportion", "hrs_proportion", "electroform_proportion"],
                    "memtorch.bh.nonideality.NonIdeality.DeviceFaults",
                )
                layer = apply_device_faults(
                    layer,
                    kwargs["lrs_proportion"],
                    kwargs["hrs_proportion"],
                    kwargs["electroform_proportion"],
                )
            elif non_ideality == NonIdeality.NonLinear:
                # ``== True`` is kept on purpose: only values equal to True
                # select simulation, exactly as before.
                if kwargs.get("simulate") == True:  # noqa: E712
                    layer = apply_non_linear(layer, simulate=True)
                else:
                    required(
                        kwargs,
                        [
                            "sweep_duration",
                            "sweep_voltage_signal_amplitude",
                            "sweep_voltage_signal_frequency",
                        ],
                        "memtorch.bh.nonideality.NonIdeality.NonLinear",
                    )
                    layer = apply_non_linear(
                        layer,
                        kwargs["sweep_duration"],
                        kwargs["sweep_voltage_signal_amplitude"],
                        kwargs["sweep_voltage_signal_frequency"],
                    )
            elif non_ideality == NonIdeality.Endurance:
                required(
                    kwargs,
                    ["x", "endurance_model", "endurance_model_kwargs"],
                    "memtorch.bh.nonideality.Endurance",
                )
                layer = apply_endurance_model(
                    layer=layer,
                    x=kwargs["x"],
                    endurance_model=kwargs["endurance_model"],
                    **kwargs["endurance_model_kwargs"]
                )
            elif non_ideality == NonIdeality.Retention:
                required(
                    kwargs,
                    ["time", "retention_model", "retention_model_kwargs"],
                    "memtorch.bh.nonideality.Retention",
                )
                layer = apply_retention_model(
                    layer=layer,
                    time=kwargs["time"],
                    retention_model=kwargs["retention_model"],
                    **kwargs["retention_model_kwargs"]
                )
            else:
                # Unrecognised entries are skipped without touching the model.
                continue
            model = _replace_submodule(model, name, layer)
    return model
def required(kwargs, arguments, call):
    """Method to check that required arguments in **kwargs are present.

    Parameters
    ----------
    kwargs : dict
        Keyword-arguments to inspect.
    arguments : list of str
        Arguments which are required to be present.
    call : str
        Name of the function being called; used only in the error message.

    Raises
    ------
    AssertionError
        If any required argument is absent from ``kwargs`` or is ``None``.
    """
    for argument in arguments:
        # .get() makes an absent key fail with the same descriptive error as
        # an explicit None instead of a bare KeyError. The exception is
        # raised explicitly so the check is not stripped under ``python -O``.
        if kwargs.get(argument) is None:
            raise AssertionError(
                "%s is required when calling %s" % (argument, call)
            )