from abc import abstractmethod
import tensorflow as tf
import numpy as np
from .losses import GAN_discriminator_loss, GAN_generator_loss
__all__ = ['ModelDes', 'BaseModel', 'GANBaseModel']
class ModelDes(object):
    """ Base class for model descriptions. """
    def ex_init_model(self, dataflow, trainer):
        self.trainer = trainer
        # may be moved into create_graph
        try:
            self.im_height = dataflow.im_size[0]
            self.im_width = dataflow.im_size[1]
        except AttributeError:
            pass
        try:
            self.num_channels = dataflow.num_channels
        except AttributeError:
            pass
@property
def get_global_step(self):
return self.trainer.get_global_step
    def set_batch_size(self, val):
self._batch_size = val
    def get_batch_size(self):
return self._batch_size
    def set_is_training(self, is_training=True):
self.is_training = is_training
    def get_train_placeholder(self):
default_plh = self._get_train_placeholder()
if not isinstance(default_plh, list):
default_plh = [default_plh]
try:
return self._train_plhs + default_plh
except AttributeError:
return default_plh
def _get_train_placeholder(self):
return []
    def set_train_placeholder(self, plhs=None):
if not isinstance(plhs, list):
plhs = [plhs]
self._train_plhs = plhs
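    # Usage sketch (assumed, not from the original code): placeholders set via
    # set_train_placeholder are prepended to whatever _get_train_placeholder()
    # returns, e.g.
    #     model.set_train_placeholder(label_pl)
    #     model.get_train_placeholder()  # -> [label_pl] + default placeholders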
# TODO to be modified
    def get_prediction_placeholder(self):
default_plh = self._get_prediction_placeholder()
if not isinstance(default_plh, list):
default_plh = [default_plh]
try:
return self._predict_plhs + default_plh
except AttributeError:
return default_plh
def _get_prediction_placeholder(self):
return []
    def set_prediction_placeholder(self, plhs=None):
if not isinstance(plhs, list):
plhs = [plhs]
self._predict_plhs = plhs
    def get_graph_feed(self):
return self._get_graph_feed()
def _get_graph_feed(self):
""" return keep_prob feed when dropout is set """
try:
if self.is_training:
feed = {self._dropout_pl: self._keep_prob}
else:
                feed = {self._dropout_pl: 1.0}
return feed
except AttributeError:
return {}
    def set_dropout(self, dropout_placeholder, keep_prob=0.5):
self._dropout_pl = dropout_placeholder
self._keep_prob = keep_prob
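    # Usage sketch (assumed, not from the original code): a subclass creates
    # keep_prob_pl = tf.placeholder(tf.float32), passes it to its dropout
    # layers, then calls set_dropout(keep_prob_pl, keep_prob=0.5). After that,
    # get_graph_feed() maps the placeholder to 0.5 while training and to 1.0
    # at inference time.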
    def create_graph(self):
        self._create_input()
        self._create_model()
        self._ex_setup_graph()
    def create_model(self, inputs=None):
        """ Only called when this model is defined inside another model. """
        print('**[warning]** consider using dictionary input.')
        assert inputs is not None, 'inputs cannot be None!'
        if not isinstance(inputs, list):
            inputs = [inputs]
        self._input = inputs
        self._create_model()
@abstractmethod
def _create_model(self):
raise NotImplementedError()
@abstractmethod
def _create_input(self):
raise NotImplementedError()
    @property
    def model_input(self):
        try:
            return self._input
        except AttributeError:
            raise AttributeError('model input has not been set yet.')
@abstractmethod
def _create_graph(self):
raise NotImplementedError()
    def _ex_setup_graph(self):
        pass
    # TODO: move outside of the class.
    # Summaries are currently created before prediction, which is unnecessary.
    def setup_summary(self):
self._setup_summary()
def _setup_summary(self):
pass
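
# A minimal sketch (assumed, not part of the original module) of a ModelDes
# subclass wiring up its input and forward pass. The placeholder and layer
# choices below are illustrative assumptions, not library code.
#
#     class MyModel(ModelDes):
#         def _create_input(self):
#             self.image = tf.placeholder(
#                 tf.float32,
#                 [None, self.im_height, self.im_width, self.num_channels])
#             self._input = [self.image]
#
#         def _create_model(self):
#             x = self.model_input[0]
#             self.output = tf.layers.dense(tf.layers.flatten(x), 10)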
class BaseModel(ModelDes):
    """ Model with a single loss and a single optimizer. """
    def get_optimizer(self):
try:
return self.optimizer
except AttributeError:
self.optimizer = self._get_optimizer()
return self.optimizer
@property
def default_collection(self):
return 'default'
def _get_optimizer(self):
raise NotImplementedError()
    def get_loss(self):
        try:
            return self._loss
        except AttributeError:
            self._loss = self._get_loss()
            tf.summary.scalar('loss_summary', self._loss,
                              collections=[self.default_collection])
            return self._loss
def _get_loss(self):
raise NotImplementedError()
    def get_grads(self):
        try:
            return self.grads
        except AttributeError:
            optimizer = self.get_optimizer()
            loss = self.get_loss()
            self.grads = optimizer.compute_gradients(loss)
            for grad, var in self.grads:
                # compute_gradients returns None for variables the loss does
                # not depend on; skip those when writing summaries.
                if grad is not None:
                    tf.summary.histogram('gradient/' + var.name, grad,
                                         collections=[self.default_collection])
            return self.grads
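
# A minimal sketch (assumed, not from the original code) of a BaseModel
# subclass: besides the ModelDes hooks, only _get_loss and _get_optimizer
# need to be provided. `self.label` and `self.output` are illustrative names
# assumed to be created in _create_input / _create_model.
#
#     class MyClassifier(BaseModel):
#         def _get_loss(self):
#             return tf.reduce_mean(
#                 tf.nn.sparse_softmax_cross_entropy_with_logits(
#                     labels=self.label, logits=self.output))
#
#         def _get_optimizer(self):
#             return tf.train.AdamOptimizer(learning_rate=1e-3)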
class GANBaseModel(ModelDes):
    """ Base model for GANs. """
    def __init__(self, input_vec_length, learning_rate):
        self.input_vec_length = input_vec_length
        assert len(learning_rate) == 2,\
            'learning_rate must contain [discriminator_lr, generator_lr]!'
        self.dis_learning_rate, self.gen_learning_rate = learning_rate
@property
def g_collection(self):
return 'default_g'
@property
def d_collection(self):
return 'default_d'
    def get_random_vec_placeholder(self):
try:
return self.Z
except AttributeError:
self.Z = tf.placeholder(tf.float32, [None, self.input_vec_length])
return self.Z
def _get_prediction_placeholder(self):
return self.get_random_vec_placeholder()
    def get_graph_feed(self):
default_feed = self._get_graph_feed()
random_input_feed = self._get_random_input_feed()
default_feed.update(random_input_feed)
return default_feed
    def _get_random_input_feed(self):
        feed = {self.get_random_vec_placeholder():
                np.random.normal(size=(self.get_batch_size(),
                                       self.input_vec_length))}
        return feed
    def _create_model(self):
        # TODO
        real_data = self.get_train_placeholder()[0]
        with tf.variable_scope('generator') as scope:
            self.gen_data = self._generator()
            scope.reuse_variables()
            self.sample_gen_data = self._generator(train=False)
        with tf.variable_scope('discriminator') as scope:
            self.d_real = self._discriminator(real_data)
            scope.reuse_variables()
            self.d_fake = self._discriminator(self.gen_data)
        with tf.name_scope('discriminator_out'):
            tf.summary.histogram('discrim_real',
                                 tf.nn.sigmoid(self.d_real),
                                 collections=[self.d_collection])
            tf.summary.histogram('discrim_gen',
                                 tf.nn.sigmoid(self.d_fake),
                                 collections=[self.d_collection])
    def get_gen_data(self):
        return self.gen_data
    def get_sample_gen_data(self):
        return self.sample_gen_data
    def def_loss(self, dis_loss_fnc, gen_loss_fnc):
        """ Update the definition of the loss functions. """
        self.d_loss = dis_loss_fnc(self.d_real, self.d_fake, name='d_loss')
        self.g_loss = gen_loss_fnc(self.d_fake, name='g_loss')
    def get_discriminator_optimizer(self):
try:
return self.d_optimizer
except AttributeError:
self.d_optimizer = self._get_discriminator_optimizer()
return self.d_optimizer
    def get_generator_optimizer(self):
try:
return self.g_optimizer
except AttributeError:
self.g_optimizer = self._get_generator_optimizer()
return self.g_optimizer
    def _get_discriminator_optimizer(self):
        # TODO: for future use
        return tf.train.AdamOptimizer(beta1=0.5,
                                      learning_rate=self.dis_learning_rate)
    def _get_generator_optimizer(self):
        # TODO: for future use
        return tf.train.AdamOptimizer(beta1=0.5,
                                      learning_rate=self.gen_learning_rate)
    def get_discriminator_loss(self):
try:
return self.d_loss
except AttributeError:
self.d_loss = self._get_discriminator_loss()
tf.summary.scalar('d_loss_summary', self.d_loss,
collections=[self.d_collection])
return self.d_loss
    def get_generator_loss(self):
try:
return self.g_loss
except AttributeError:
self.g_loss = self._get_generator_loss()
tf.summary.scalar('g_loss_summary', self.g_loss,
collections=[self.g_collection])
return self.g_loss
def _get_discriminator_loss(self):
return GAN_discriminator_loss(self.d_real, self.d_fake,
name='d_loss')
def _get_generator_loss(self):
return GAN_generator_loss(self.d_fake, name='g_loss')
    def get_discriminator_grads(self):
        try:
            return self.d_grads
        except AttributeError:
            d_training_vars = [v for v in tf.trainable_variables()
                               if v.name.startswith('discriminator/')]
            optimizer = self.get_discriminator_optimizer()
            loss = self.get_discriminator_loss()
            self.d_grads = optimizer.compute_gradients(loss,
                                                       var_list=d_training_vars)
            for grad, var in self.d_grads:
                # Skip variables without gradients (compute_gradients
                # returns None for those) before writing summaries.
                if grad is not None:
                    tf.summary.histogram('d_gradient/' + var.name, grad,
                                         collections=[self.d_collection])
            return self.d_grads
    def get_generator_grads(self):
        try:
            return self.g_grads
        except AttributeError:
            g_training_vars = [v for v in tf.trainable_variables()
                               if v.name.startswith('generator/')]
            optimizer = self.get_generator_optimizer()
            loss = self.get_generator_loss()
            self.g_grads = optimizer.compute_gradients(loss,
                                                       var_list=g_training_vars)
            for grad, var in self.g_grads:
                # Skip variables without gradients.
                if grad is not None:
                    tf.summary.histogram('g_gradient/' + var.name, grad,
                                         collections=[self.g_collection])
            return self.g_grads
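
# A minimal sketch (assumed, not part of the original module) of how a
# GANBaseModel subclass could be defined. `_generator` and `_discriminator`
# are the hooks _create_model expects; the layer choices below are
# illustrative only, not the library's implementation.
#
#     class MyGAN(GANBaseModel):
#         def _generator(self, train=True):
#             z = self.get_random_vec_placeholder()
#             return tf.layers.dense(z, 784, activation=tf.tanh)
#
#         def _discriminator(self, inputs):
#             return tf.layers.dense(inputs, 1)  # un-normalized logits
#
#     model = MyGAN(input_vec_length=100, learning_rate=[2e-4, 2e-4])
#     # A trainer would then build train ops from
#     # get_discriminator_grads() / get_generator_grads().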