Source code for Fireworks.core.component_map

import torch
from .message import Message
from torch.nn import Parameter, Module

class Component_Map(dict):
    """
    Each of the main objects that can be used to construct pipelines (Pipes, Junctions, and Models)
    has a means for tracking its internal state. For Pipes, this is handled by a simple dictionary,
    but for Junctions and Models we use this class, which satisfies the more complex needs of those
    objects. In particular, a Component_Map can track whether a variable is internal or external to
    a given object, and if it's external, whom the variable belongs to. This lets us dynamically
    assign a variable of one object to a component of a Junction or Model while maintaining the
    distinction that the assigned variable is not internal to the Junction or Model. This
    distinction can be useful for variables such as hyperparameters or runtime configurations
    (e.g. whether to use CUDA) that one does not want to store alongside variables like model
    weights. You can also have a Model 'borrow' variables from another Model while maintaining this
    distinction (e.g. use the first two layers of another model, then use internal weights for the
    remaining layers), and this can be useful when training Models (you could have your optimizer
    operate only on a Model's internal parameters, treating everything else as constant). These are
    just a few examples of how this abstraction can be useful. In simpler terms, it is essentially
    a means to deliberately pass variables by reference, which is not how Python's memory model
    operates by default, but which can be extremely helpful when doing machine learning.

    The details of the interaction with a Component_Map are abstracted away by Junctions and
    Models. Hence, you shouldn't have to interact with a Component_Map directly. Instead, you can
    generally just call set_state and get_state on Junctions and Models to get serialized
    representations of their Component_Maps. The format of this serialization is a dict of the
    form {'external': {...}, 'internal': {...}}. The 'internal' dict contains a mapping between
    variable names and those variables. The 'external' dict contains a mapping between variable
    names and the objects that those variables belong to. In this way, a Component_Map can keep
    track of the owner of a linked variable and also get its value as needed. Hence, Junctions and
    Models can simply use that variable as if it were internal, and this makes it easy to swap
    variables around without changing syntax (e.g. replace some internal component of a Model with
    an attribute of some other object on the fly).

    A Component_Map behaves like a dict with the special property that if you assign a tuple of
    the form (obj, x) to the dict, where x is a string, then the Component_Map will treat that as
    a 'pass by reference' assignment. In other words, it will assume that you want to externally
    link the variable obj.x to the Component_Map. For example, if you do this:

    ::

        A = some_object()
        cm = Component_Map({})
        cm['a'] = (A, 'x')

    then whenever you access cm['a'], you will get whatever is returned by A.x.

    ::

        cm['a'] == A.x # This evaluates to True.
        cm['a'] is A.x # This also evaluates to True, because the assignment is by reference.

    If you call cm.get_state(), the 'external' dict will contain a reference to A.

    ::

        state = cm.get_state()
        external = state['external']
        external['a'] == (A, 'x') # This evaluates to True.

    On the other hand, if you do this:

    ::

        cm['a'] = A.x # Don't pass by reference.
        cm['a'] == A.x # This evaluates to True.
        cm['a'] is A.x # This may or may not be True, because Python sometimes assigns by
                       # reference and sometimes copies data depending on the situation.

    this will be treated as an internal assignment. Note that PyTorch implements logic for
    enforcing pass-by-reference for torch.nn.Parameter objects. Hence, if A.x were a Parameter,
    then the assignment would be by reference. However, we would have no way of knowing who the
    'owner' of the Parameter is, and by using Component_Maps, we are able to extend this
    functionality to any Python object. If you now get the state, the variable will be in the
    'internal' dict.

    ::

        state = cm.get_state()
        internal = state['internal']
        internal['a'] == A.x # This evaluates to True. If A.x is vector/tensor-valued, the
                             # comparison is elementwise, so you may get a vector/tensor of 1's.
    """

    def __init__(self, components):
        """
        self._internal_components functions like a normal dictionary. If you call cm['a'] and 'a'
        is in the internal_components dict, then it will return cm._internal_components['a'].
        However, if 'a' is external, then cm['a'] will return
        getattr(cm._external_modules['a'], cm._external_attribute_names['a']).
        """
        dict.__init__(self)
        self._external_modules = {}
        self._external_attribute_names = {}
        self._internal_components = {}
        for key, value in components.items():
            self[key] = value

    def __setitem__(self, key, val):
        """
        This overrides the __setitem__ method of dict so that if you set an item of the form
        (obj, x) to a key k, where x is a string, the Component_Map will 'link' obj.x to k,
        i.e. cm['k'] == obj.x. It will do this by inserting obj into the _external_modules dict
        and x into the _external_attribute_names dict, with k as the key for both. Hence, when
        you call cm['k'], the Component_Map will fetch obj and the attribute name x and then call
        getattr(obj, x).

        Note that you cannot have a key in both the internal and external dicts. If you already
        have a key in one and you assign to the other, the former will be deleted. This prevents
        ambiguity when accessing elements.
        """
        if type(val) is tuple and len(val) == 2 and type(val[1]) is str and hasattr(val[0], val[1]):
            # Very specific test to check if the intention is to link an attribute inside of
            # another object to this Component_Map rather than simply set the value of the key
            # to a tuple.
            if key in self._internal_components: # Delete from internal components if this key already exists.
                del self[key]
            obj, attribute = val
            value = getattr(obj, attribute)
            # key, value = self.setitem_hook(key, value)
            # self._external_components[key] = value
            self._external_modules[key] = obj
            self._external_attribute_names[key] = attribute
        else:
            if key in self._external_modules: # Delete from external components if this key already exists.
                del self[key]
            value = val
            key, value = self.setitem_hook(key, value)
            self._internal_components[key] = value

        dict.__setitem__(self, key, value)

    def __delitem__(self, key):
        """
        This also deletes the key from the internal and external dicts.
        """
        dict.__delitem__(self, key)
        for d in [self._external_attribute_names, self._external_modules, self._internal_components]:
            if key in d:
                del d[key]

    def __getitem__(self, key):
        """
        This returns the object referenced by the key, whether it is internal or external.
        """
        if key in self._internal_components:
            return self._internal_components[key]
        elif key in self._external_modules:
            return getattr(self._external_modules[key], self._external_attribute_names[key])
        else:
            raise KeyError(key)
    def setitem_hook(self, key, value):
        """
        This can be overridden by a subclass in order to implement specific actions that should
        take place before an attribute is set.
        """
        return key, value
    def set_state(self, state):
        """
        This method can be used to apply a serialized representation of state to a Component_Map
        all at once. This is used for loading in saved data.

        Args:
            state: A dict of the form {'external': {...}, 'internal': {...}}. The elements of
                this dict will be assigned to the Component_Map. Note that this will not reset
                the Component_Map, so any elements already present will remain in the
                Component_Map.
        """
        for key, value in {**state['internal'], **state['external']}.items():
            self[key] = value
    def get_state(self):
        """
        This returns a serialized representation of the state of the Component_Map.

        Args:
            None

        Returns:
            state: A dict of the form {'external': {...}, 'internal': {...}}. See above
                documentation for more information.
        """
        internal = self._internal_components
        external = {key: (self._external_modules[key], self._external_attribute_names[key])
                    for key in self._external_modules}

        return {'internal': internal, 'external': external}
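# The following demo is an illustrative sketch added for documentation purposes; it is not part
# of the original module. `Owner` and `_component_map_demo` are hypothetical names used only to
# show internal vs. external assignment and the get_state/set_state round trip.
def _component_map_demo():
    class Owner:
        def __init__(self):
            self.x = [1, 2, 3]

    A = Owner()
    cm = Component_Map({})
    cm['a'] = (A, 'x')   # External assignment: 'a' is linked to A.x by reference.
    cm['b'] = [4, 5, 6]  # Internal assignment: 'b' is stored in the map itself.
    assert cm['a'] is A.x
    A.x = [7, 8, 9]
    assert cm['a'] == [7, 8, 9]  # Lookups always reflect the owner's current value.

    state = cm.get_state()
    assert state['external']['a'] == (A, 'x')
    assert state['internal']['b'] == [4, 5, 6]

    cm2 = Component_Map({})
    cm2.set_state(state)  # Round trip: rebuilds both the internal copy and the external link.
    assert cm2['a'] is A.x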
class PyTorch_Component_Map(Component_Map):
    """
    This is a subclass of Component_Map with additional functionality for dealing with PyTorch
    data structures. PyTorch has a lot of logic in the background to keep track of Parameters and
    gradients and of where objects are located in memory. The PyTorch_Component_Map has a modified
    __setitem__ method which ensures that there are no conflicts with any of these background
    operations by PyTorch. In particular, a PyTorch_Component_Map can have a (PyTorch) Model
    assigned to it, and whenever __setitem__ is called, the item is:

    1) Converted to a torch.nn.Parameter object if possible. This is essential for computing
       gradients and training the parameter.

    2) Recursively assigned if necessary. This concept is best explained with an example. Say you
       have a neural network with a convolutional layer:

       ::

           model = some_pytorch_model()
           model.conv1 = torch.nn.Conv2d(4,4,4) # A convolutional layer with 4 input channels, 4 output channels, and a 4x4 kernel.

       'model.conv1' is itself a PyTorch Module with its own internal state, and in general,
       models can have models that have models, and so on. In other words, 'model.conv1' could
       itself have variables that are Modules/Models, and so on. When you get the state dict for
       the original model, you will get nested dictionaries. These can still be serialized and
       saved to a file like normal, but when we call set_state, we want to make sure that we
       assign these nested dictionary elements to the correct submodules.

       ::

           state = model.get_state()
           internal = state['internal']
           internal['conv1'] == {'weights': ['This is some Tensor'], 'bias': ['This is some vector']}

       If we naively called model.set_state(state) to load some other state from a file, then we
       would end up assigning a nested dictionary as the value of model.conv1. What we actually
       want is:

       ::

           model.set_state(state)
           print(model.conv1) # This is a PyTorch Module
           print(model.conv1.weights) # This is a Tensor
           print(model.conv1.bias) # This is a Tensor

       PyTorch_Component_Map checks if the attribute being assigned to is a PyTorch_Model or
       (PyTorch) Module and performs this type of assignment.

    3) 'Registered' to the Model. This is something that PyTorch does whenever you assign a value
       to a PyTorch Module, and it is essential for the proper functioning of PyTorch
       methods/functions, such as getting a state_dict, submodules, etc.

    This additional logic is important because, in general, all of the layers of a neural network
    are implemented as Modules, and PyTorch Modules inherently have a nested structure.
    """

    def __init__(self, components, model=None):
        """
        If a model is provided, then the PyTorch_Component_Map can register values with that
        model, which is essential for proper usage with PyTorch Modules.
        """
        self.model = model
        Component_Map.__init__(self, components)

    def __setitem__(self, key, val):
        """
        This adds the additional logic described above to __setitem__.
        """
        # This allows the Module superclass to register the parameters.
        if self.model is not None and key != 'components' and hasattr(self.model, key) and isinstance(val, dict):
            # Setting the state dict for a submodule.
            submodule = getattr(self.model, key)
            if hasattr(submodule, 'set_state'): # Is a Model.
                submodule.set_state(val)
            elif isinstance(submodule, Module): # Convert state dict to a dict of Tensors.
                val = {k: torch.Tensor(v) for k, v in val.items()}
                submodule.load_state_dict(val)
            elif isinstance(submodule, dict): # It was supposed to be a dict.
                Component_Map.__setitem__(self, key, val)
                if self.model is not None:
                    val = self[key]
                    # Temporarily clear the initialization flag so that setattr registers the
                    # value with the model without triggering component re-initialization.
                    i = self.model._flags['components_initialized']
                    self.model._flags['components_initialized'] = 0
                    setattr(self.model, key, val)
                    self.model._flags['components_initialized'] = i
        else:
            Component_Map.__setitem__(self, key, val)
            if self.model is not None:
                val = self[key]
                i = self.model._flags['components_initialized']
                self.model._flags['components_initialized'] = 0
                setattr(self.model, key, val)
                self.model._flags['components_initialized'] = i
    def hook(self, key, value):
        """
        This is used to (try to) convert objects to torch.nn.Parameter objects upon assignment.
        If a value has a tensor-like structure (i.e. is a list or ndarray), then it will
        automatically be converted.
        """
        if not isinstance(value, Parameter) and not isinstance(value, Module) and hasattr(value, '__len__'):
            # Convert to Parameter.
            try:
                value = Parameter(torch.Tensor(value))
            except Exception:
                # If the component is not tensor-like, a Parameter, or a Module, then it is some
                # other object that we simply attach to the model. For example, it could be a
                # Pipe or Junction that the model can call upon.
                pass

        return key, value
    def setitem_hook(self, key, value):
        """
        This assigns the above hook as the setitem_hook so that it will be triggered upon every
        __setitem__ call.
        """
        return self.hook(key, value)
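# Illustrative sketch (not part of the original module): with no model attached (model=None),
# PyTorch_Component_Map still converts tensor-like values to trainable Parameters on assignment.
# `_pytorch_component_map_demo` is a hypothetical name used only for this demonstration.
def _pytorch_component_map_demo():
    cm = PyTorch_Component_Map({})
    cm['weights'] = [1.0, 2.0, 3.0]  # The list is converted by the hook on assignment...
    assert isinstance(cm['weights'], Parameter)  # ...into a torch.nn.Parameter.
    assert 'weights' in cm.get_state()['internal']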