# encoding: UTF-8
# Copyright 2017 Google.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import glob
import sys
import codecs
# size of the alphabet that we work with
ALPHASIZE = 98


# Specification of the supported alphabet (subset of ASCII-7):
# 9       tab TAB
# 10      line feed LF
# 32-64   space, numbers and punctuation
# 65-90   upper-case letters
# 91-96   more punctuation
# 97-122  lower-case letters
# 123-126 more punctuation
def convert_from_alphabet(a):
    """Encode a character.
    :param a: one character, given as its ASCII code (e.g. ord('x'))
    :return: the encoded value
    """
    if a == 9:
        return 1  # tab
    if a == 10:
        return 127 - 30  # LF
    elif 32 <= a <= 126:
        return a - 30
    else:
        return 0  # unknown
# encoded values:
# unknown = 0
# tab = 1
# space = 2
# all chars from 32 to 126 = c-30
# LF mapped to 127-30
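
# Example mappings, following the scheme above (illustrative values):
#   convert_from_alphabet(ord('a'))   -> 67  (97 - 30)
#   convert_from_alphabet(ord(' '))   -> 2   (32 - 30)
#   convert_from_alphabet(ord('\n'))  -> 97  (127 - 30)
#   convert_from_alphabet(ord('é'))   -> 0   (outside the supported alphabet)
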
def convert_to_alphabet(c, avoid_tab_and_lf=False):
    """Decode a code point.
    :param c: code point
    :param avoid_tab_and_lf: if True, tab is replaced by a space and line feed by '\\'
    :return: decoded character, as its ASCII code
    """
    if c == 1:
        return 32 if avoid_tab_and_lf else 9  # space instead of TAB
    if c == 127 - 30:
        return 92 if avoid_tab_and_lf else 10  # \ instead of LF
    if 32 <= c + 30 <= 126:
        return c + 30
    else:
        return 0  # unknown
def encode_text(s):
"""Encode a string.
:param s: a text string
:return: encoded list of code points
"""
return list(map(lambda a: convert_from_alphabet(ord(a)), s))
def decode_to_text(c, avoid_tab_and_lf=False):
    """Decode an encoded string.
    :param c: encoded list of code points
    :param avoid_tab_and_lf: if True, tab is replaced by a space and line feed by '\\'
    :return: the decoded text string
    """
    return "".join(map(lambda a: chr(convert_to_alphabet(a, avoid_tab_and_lf)), c))
def sample_from_probabilities(probabilities, topn=ALPHASIZE):
    """Roll the dice to produce a random integer in the [0..ALPHASIZE-1] range,
    according to the provided probabilities. If topn is specified, only the
    topn highest probabilities are taken into account.
    :param probabilities: a list of size ALPHASIZE with individual probabilities
    :param topn: the number of highest probabilities to consider. Defaults to all of them.
    :return: a random integer
    """
    p = np.squeeze(probabilities)
    p[np.argsort(p)[:-topn]] = 0  # zero out everything but the topn highest probabilities
    p = p / np.sum(p)  # renormalise
    return np.random.choice(ALPHASIZE, 1, p=p)[0]
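
# Illustrative sampling step (a sketch; `rnn_output` is a hypothetical softmax
# output of shape [1, ALPHASIZE], not defined in this module):
#   next_code = sample_from_probabilities(rnn_output, topn=3)
#   next_char = chr(convert_to_alphabet(next_code))
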
def rnn_minibatch_sequencer(raw_data, batch_size, sequence_size, nb_epochs):
    """
    Divides the data into batches of sequences so that all the sequences in one batch
    continue in the next batch. This is a generator that will keep returning batches
    until the input data has been seen nb_epochs times. Sequences are continued even
    between epochs, apart from the one corresponding to the end of raw_data.
    The remainder at the end of raw_data that does not fit in a full batch is ignored.
    :param raw_data: the training text
    :param batch_size: the size of a training minibatch
    :param sequence_size: the unroll size of the RNN
    :param nb_epochs: number of epochs to train on
    :return:
        x: one batch of training sequences
        y: one batch of target sequences, i.e. training sequences shifted by 1
        epoch: the current epoch number (starting at 0)
    """
data = np.array(raw_data)
data_len = data.shape[0]
# using (data_len-1) because we must provide for the sequence shifted by 1 too
nb_batches = (data_len - 1) // (batch_size * sequence_size)
assert nb_batches > 0, "Not enough data, even for a single batch. Try using a smaller batch_size."
rounded_data_len = nb_batches * batch_size * sequence_size
xdata = np.reshape(data[0:rounded_data_len], [batch_size, nb_batches * sequence_size])
ydata = np.reshape(data[1:rounded_data_len + 1], [batch_size, nb_batches * sequence_size])
for epoch in range(nb_epochs):
for batch in range(nb_batches):
x = xdata[:, batch * sequence_size:(batch + 1) * sequence_size]
y = ydata[:, batch * sequence_size:(batch + 1) * sequence_size]
x = np.roll(x, -epoch, axis=0) # to continue the text from epoch to epoch (do not reset rnn state!)
y = np.roll(y, -epoch, axis=0)
yield x, y, epoch
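
# Sketch of a consumer loop (assumes `codetext` is an encoded corpus, e.g. from
# read_data_files below; parameter values are illustrative):
#   for x, y, epoch in rnn_minibatch_sequencer(codetext, batch_size=200,
#                                              sequence_size=30, nb_epochs=10):
#       # feed x (inputs) and y (targets) to the model, carrying the RNN state
#       # across batches, since sequences continue from one batch to the next
#       ...
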
def find_book(index, bookranges):
    """Returns the name of the book containing character position 'index'."""
    return next(
        book["name"] for book in bookranges if (book["start"] <= index < book["end"]))


def find_book_index(index, bookranges):
    """Returns the position in 'bookranges' of the book containing character position 'index'."""
    return next(
        i for i, book in enumerate(bookranges) if (book["start"] <= index < book["end"]))
def print_learning_learned_comparison(X, Y, losses, bookranges, batch_loss, batch_accuracy, epoch_size, index, epoch):
"""Display utility for printing learning statistics"""
print()
# epoch_size in number of batches
batch_size = X.shape[0] # batch_size in number of sequences
sequence_len = X.shape[1] # sequence_len in number of characters
start_index_in_epoch = index % (epoch_size * batch_size * sequence_len)
for k in range(batch_size):
index_in_epoch = index % (epoch_size * batch_size * sequence_len)
decx = decode_to_text(X[k], avoid_tab_and_lf=True)
decy = decode_to_text(Y[k], avoid_tab_and_lf=True)
bookname = find_book(index_in_epoch, bookranges)
formatted_bookname = "{: <10.40}".format(bookname) # min 10 and max 40 chars
epoch_string = "{:4d}".format(index) + " (epoch {}) ".format(epoch)
loss_string = "loss: {:.5f}".format(losses[k])
print_string = epoch_string + formatted_bookname + "{}{}{}"
print(print_string.format(decx, decy, loss_string))
index += sequence_len
# box formatting characters:
# │ \u2502
# ─ \u2500
# └ \u2514
# ┘ \u2518
# ┴ \u2534
# ┌ \u250C
# ┐ \u2510
format_string = "{:─^" + str(len(epoch_string)) + "}"
format_string += "{:─^" + str(len(formatted_bookname)) + "}"
format_string += "{:─^" + str(len(decx) + 2) + "}"
format_string += "{:─^" + str(len(decy) + 2) + "}"
format_string += "{:─^" + str(len(loss_string)) + "}┘"
footer = format_string.format('INDEX', 'BOOK NAME', 'TRAINING SEQUENCE', 'PREDICTED SEQUENCE', 'LOSS')
print(footer)
# print statistics
batch_index = start_index_in_epoch // (batch_size * sequence_len)
batch_string = "batch {}/{} in epoch {},".format(batch_index, epoch_size, epoch)
stats = "{: <28} batch loss: {:.5f}, batch accuracy: {:.5f}".format(batch_string, batch_loss, batch_accuracy)
print()
print("TRAINING STATS: {}".format(stats))
class Progress:
    """Text mode progress bar.
    Usage:
            p = Progress(30)
            p.step()
            p.step()
            p.step(reset=True)  # to restart from 0%
    The progress bar displays a new header at each restart."""

    def __init__(self, maxi, size=100, msg=""):
        """
        :param maxi: the number of steps required to reach 100%
        :param size: the number of characters taken on the screen by the progress bar
        :param msg: the message displayed in the header of the progress bar
        """
        self.maxi = maxi
        self.p = self.__start_progress(maxi)()  # () to get the iterator from the generator
        self.header_printed = False
        self.msg = msg
        self.size = size
def step(self, reset=False):
if reset:
self.__init__(self.maxi, self.size, self.msg)
if not self.header_printed:
self.__print_header()
next(self.p)
def __print_header(self):
print()
format_string = "0%{: ^" + str(self.size - 6) + "}100%"
print(format_string.format(self.msg))
self.header_printed = True
def __start_progress(self, maxi):
def print_progress():
            # Bresenham's algorithm. Yields the number of '=' characters printed.
            # This always prints exactly self.size characters over maxi invocations.
dx = maxi
dy = self.size
d = dy - dx
for x in range(maxi):
k = 0
while d >= 0:
print('=', end="", flush=True)
k += 1
d -= dx
d += dy
yield k
return print_progress
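
# Illustrative usage of Progress (values are examples only):
#   progress = Progress(1000, msg="Training on next 1000 batches")
#   for step in range(1000):
#       progress.step()  # prints the header once, then a 100-character bar
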
def read_data_files(directory, validation=True):
    """Read data files according to the specified glob pattern.
    Optionally set aside the last file(s) as validation data.
    No validation data is returned if there are 5 files or fewer.
    :param directory: a glob pattern, for example "data/*.txt"
    :param validation: if True (default), sets the last file(s) aside as validation data
    :return: training data, validation data, list of loaded file names with ranges
        If validation is False, the validation data is empty.
    """
    codetext = []
    bookranges = []
    shakelist = glob.glob(directory, recursive=True)
    for shakefile in shakelist:
        print("Loading file " + shakefile)
        with codecs.open(shakefile, "r", encoding='utf-8', errors='ignore') as shaketext:
            start = len(codetext)
            codetext.extend(encode_text(shaketext.read()))
            end = len(codetext)
            bookranges.append({"start": start, "end": end, "name": shakefile.rsplit("/", 1)[-1]})
if len(bookranges) == 0:
sys.exit("No training data has been found. Aborting.")
# For validation, use roughly 90K of text,
# but no more than 10% of the entire text
# and no more than 1 book in 5 => no validation at all for 5 files or fewer.
    # 10% of the text is how many books?
total_len = len(codetext)
validation_len = 0
nb_books1 = 0
for book in reversed(bookranges):
validation_len += book["end"]-book["start"]
nb_books1 += 1
if validation_len > total_len // 10:
break
    # 90K of text is how many books?
validation_len = 0
nb_books2 = 0
for book in reversed(bookranges):
validation_len += book["end"]-book["start"]
nb_books2 += 1
if validation_len > 90*1024:
break
    # 20% of the books is how many books?
nb_books3 = len(bookranges) // 5
# pick the smallest
nb_books = min(nb_books1, nb_books2, nb_books3)
if nb_books == 0 or not validation:
cutoff = len(codetext)
else:
cutoff = bookranges[-nb_books]["start"]
valitext = codetext[cutoff:]
codetext = codetext[:cutoff]
return codetext, valitext, bookranges
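
# Example call (illustrative; the glob pattern and epoch_size are assumptions):
#   codetext, valitext, bookranges = read_data_files("data/*.txt", validation=True)
#   print_data_stats(len(codetext), len(valitext), epoch_size=100)
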
def print_data_stats(datalen, valilen, epoch_size):
    datalen_mb = datalen / 1024.0 / 1024.0
    valilen_kb = valilen / 1024.0
    print("Training text size is {:.2f}MB with {:.2f}KB set aside for validation.".format(datalen_mb, valilen_kb)
          + " There will be {} batches per epoch.".format(epoch_size))
def print_validation_header(validation_start, bookranges):
bookindex = find_book_index(validation_start, bookranges)
books = ''
for i in range(bookindex, len(bookranges)):
books += bookranges[i]["name"]
if i < len(bookranges)-1:
books += ", "
print("{: <60}".format("Validating on " + books), flush=True)
def print_validation_stats(loss, accuracy):
    print("VALIDATION STATS: loss: {:.5f}, accuracy: {:.5f}".format(loss, accuracy))
def print_text_generation_header():
print()
print("{:─^111}".format('Generating random text from learned state'))
def print_text_generation_footer():
print()
print("{:─^111}".format('End of generation'))
def frequency_limiter(n, multiple=1, modulo=0):
    """Returns a predicate on an iteration counter that is True once every
    multiple*n iterations, i.e. when i % (multiple*n) == modulo*multiple."""
    def limit(i):
        return i % (multiple * n) == modulo * multiple
    return limit
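
# Illustrative use: run a validation step once every 50 training iterations
# (run_validation is a hypothetical function, not part of this module):
#   do_validation = frequency_limiter(50)
#   if do_validation(step):
#       run_validation()  # fires at step 0, 50, 100, ...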