Module flashy.adversarial
For training adversarial losses, we provide an AdversarialLoss wrapper that encapsulates the training of the adversary. This keeps the main training loop simple and confines the complexity of adversarial training to this utility class.
Source code
"""For training adversarial losses, we provide an AdversarialLoss wrapper that encapsulate
the training of the adversarial loss. This allows us to keep the main training loop simple and
to encapsulate the complexity of the adversarial loss training inside this utility class.
"""
from contextlib import contextmanager
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
from . import distrib
@contextmanager
def readonly(model: nn.Module):
    """Temporarily switches off gradient computation for the given model.
    """
    state = []
    for p in model.parameters():
        state.append(p.requires_grad)
        p.requires_grad_(False)
    try:
        yield
    finally:
        for p, s in zip(model.parameters(), state):
            p.requires_grad_(s)
LossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
class AdversarialLoss(nn.Module):
    """
    This is an example class for handling adversarial losses without cluttering
    the main training loop. It will not fit all use cases and will need subclassing
    for more complex uses (e.g. gradient penalty, or feature losses).

    Args:
        adversary: this will be used to estimate the logits given the fake and real samples.
            We use the convention that the output is high for fake samples.
        optimizer: optimizer used for training the given module.
        loss: loss function, by default binary_cross_entropy_with_logits.

    Example of usage:
        adv_loss = AdversarialLoss(module, optimizer, loss)
        for real in loader:
            noise = torch.randn(...)
            fake = model(noise)
            adv_loss.train_adv(fake, real)
            loss = adv_loss(fake)
            loss.backward()
    """
    def __init__(self, adversary: nn.Module, optimizer: torch.optim.Optimizer,
                 loss: LossType = F.binary_cross_entropy_with_logits):
        super().__init__()
        self.adversary = adversary
        distrib.broadcast_weights(adversary.parameters())
        self.optimizer = optimizer
        self.loss = loss
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Add the optimizer state dict inside our own.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
        return destination

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Load optimizer state.
        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
    def train_adv(self, fake: torch.Tensor, real: torch.Tensor):
        """Train the adversary with the given fake and real examples.
        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_grad`)
        and call the optimizer.
        """
        logit_fake_is_fake = self.adversary(fake.detach())
        logit_real_is_fake = self.adversary(real.detach())
        one = torch.tensor(1., device=fake.device).expand_as(logit_fake_is_fake)
        zero = torch.tensor(0., device=fake.device).expand_as(logit_real_is_fake)
        loss = self.loss(logit_fake_is_fake, one) + self.loss(logit_real_is_fake, zero)
        self.optimizer.zero_grad()
        with distrib.eager_sync_grad(self.adversary.parameters()):
            loss.backward()
        self.optimizer.step()
        return loss
    def forward(self, fake: torch.Tensor):
        """Return the loss for the generator, i.e. trying to fool the adversary.
        """
        with readonly(self.adversary):
            logit_fake_is_fake = self.adversary(fake)
            zero = torch.tensor(0., device=fake.device).expand_as(logit_fake_is_fake)
            loss_generator = self.loss(logit_fake_is_fake, zero)
        return loss_generator
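One detail worth noting in the source above: because _save_to_state_dict and _load_from_state_dict are overridden, the optimizer state is saved and restored together with the adversary's weights. Below is a minimal sketch of a checkpoint round-trip; the toy linear adversary is illustrative, and a single-process run is assumed so that flashy's distrib helpers act as no-ops.

import torch
from torch import nn
from flashy.adversarial import AdversarialLoss

adversary = nn.Linear(16, 1)
optimizer = torch.optim.Adam(adversary.parameters(), lr=1e-4)
adv_loss = AdversarialLoss(adversary, optimizer)

# The optimizer state travels inside the module's own state dict,
# under the 'optimizer' key, next to the adversary parameters.
state = adv_loss.state_dict()
assert 'optimizer' in state and 'adversary.weight' in state

# Loading restores both the adversary weights and the optimizer state;
# _load_from_state_dict pops the 'optimizer' key before the parent's strict check.
adv_loss.load_state_dict(state)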
Functions
def readonly(model: torch.nn.modules.module.Module)
Temporarily switches off gradient computation for the given model.
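A small sketch of readonly used on its own, showing that the previous requires_grad flags are restored on exit (even if the block raises, thanks to the try/finally):

import torch
from torch import nn
from flashy.adversarial import readonly

model = nn.Linear(4, 2)
assert all(p.requires_grad for p in model.parameters())

with readonly(model):
    # No gradients will be recorded for the model's parameters here.
    assert not any(p.requires_grad for p in model.parameters())
    out = model(torch.randn(3, 4))

# The original flags are restored on exit.
assert all(p.requires_grad for p in model.parameters())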
Classes
class AdversarialLoss (adversary: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, loss: Union[torch.nn.modules.module.Module, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = <function binary_cross_entropy_with_logits>)
This is an example class for handling adversarial losses without cluttering the main training loop. It will not fit all use cases and will need subclassing for more complex uses (e.g. gradient penalty, or feature losses).
Args
adversary: this will be used to estimate the logits given the fake and real samples. We use the convention that the output is high for fake samples.
optimizer: optimizer used for training the given module.
loss: loss function, by default binary_cross_entropy_with_logits.
Example of usage:

    adv_loss = AdversarialLoss(module, optimizer, loss)
    for real in loader:
        noise = torch.randn(...)
        fake = model(noise)
        adv_loss.train_adv(fake, real)
        loss = adv_loss(fake)
        loss.backward()
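Expanding the usage example into a complete loop, here is a hedged sketch with a toy generator and adversary; latent_dim, data_dim and the stand-in modules are illustrative, and a single-process run is assumed so the distrib calls are no-ops:

import torch
from torch import nn
from flashy.adversarial import AdversarialLoss

latent_dim, data_dim = 8, 16
generator = nn.Linear(latent_dim, data_dim)
adversary = nn.Sequential(nn.Linear(data_dim, 32), nn.ReLU(), nn.Linear(32, 1))

gen_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)
adv_opt = torch.optim.Adam(adversary.parameters(), lr=1e-4)
adv_loss = AdversarialLoss(adversary, adv_opt)

loader = [torch.randn(4, data_dim) for _ in range(10)]  # stand-in for a real loader
for real in loader:
    noise = torch.randn(real.shape[0], latent_dim)
    fake = generator(noise)
    # Adversary step: train_adv detaches fake, so the generator is untouched.
    adv_loss.train_adv(fake, real)
    # Generator step: the adversary is frozen inside adv_loss.forward.
    gen_opt.zero_grad()
    loss = adv_loss(fake)
    loss.backward()
    gen_opt.step()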
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, fake: torch.Tensor) ‑> torch.Tensor
Return the loss for the generator, i.e. trying to fool the adversary.
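Because the adversary is wrapped in readonly during this forward, the backward pass populates gradients only for the generator. A small hedged check with toy modules standing in for real ones (single-process run assumed so flashy's distrib helpers are no-ops):

import torch
from torch import nn
from flashy.adversarial import AdversarialLoss

generator = nn.Linear(8, 16)
adversary = nn.Linear(16, 1)
adv_loss = AdversarialLoss(adversary, torch.optim.Adam(adversary.parameters()))

fake = generator(torch.randn(4, 8))
adv_loss(fake).backward()

# Generator parameters received gradients...
assert all(p.grad is not None for p in generator.parameters())
# ...while the adversary's were left untouched.
assert all(p.grad is None for p in adversary.parameters())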
def train_adv(self, fake: torch.Tensor, real: torch.Tensor)
Train the adversary with the given fake and real examples. This will automatically synchronize gradients (with flashy.distrib.eager_sync_grad) and call the optimizer.
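The class docstring names gradient penalty as one intended extension via subclassing. Below is a hedged sketch of such a subclass adding an R1-style penalty on real samples; the class name, penalty_weight and the specific penalty are illustrative, not part of flashy:

import torch
from flashy import distrib
from flashy.adversarial import AdversarialLoss

class AdversarialLossWithR1(AdversarialLoss):
    """Hypothetical extension adding an R1 gradient penalty on real samples."""

    def __init__(self, *args, penalty_weight: float = 10., **kwargs):
        super().__init__(*args, **kwargs)
        self.penalty_weight = penalty_weight

    def train_adv(self, fake: torch.Tensor, real: torch.Tensor):
        # Real samples must require grad so we can penalize the input gradient.
        real = real.detach().requires_grad_(True)
        logit_fake_is_fake = self.adversary(fake.detach())
        logit_real_is_fake = self.adversary(real)
        one = torch.ones_like(logit_fake_is_fake)
        zero = torch.zeros_like(logit_real_is_fake)
        loss = self.loss(logit_fake_is_fake, one) + self.loss(logit_real_is_fake, zero)
        # R1 penalty: squared norm of the gradient of the real logits w.r.t. the input.
        grads, = torch.autograd.grad(logit_real_is_fake.sum(), real, create_graph=True)
        loss = loss + self.penalty_weight * grads.pow(2).flatten(1).sum(dim=1).mean()
        self.optimizer.zero_grad()
        with distrib.eager_sync_grad(self.adversary.parameters()):
            loss.backward()
        self.optimizer.step()
        return loss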