Skip to content
gqlxj1987's Blog
Go back

Neural Machine Translation

Edit page

repo

数据预处理

原文链接

使用了新的Dataset API部分

src_dataset=tf.data.TextLineDataset('src_data.txt')  
tgt_dataset=tf.data.TextLineDataset('tgt_data.txt')  

查找表的构造方法:

def create_vocab_tables(src_vocab_file, tgt_vocab_file, share_vocab):
  """Creates vocab tables for src_vocab_file and tgt_vocab_file."""
  src_vocab_table = lookup_ops.index_table_from_file(
      src_vocab_file, default_value=UNK_ID)
  if share_vocab:
    tgt_vocab_table = src_vocab_table
  else:
    tgt_vocab_table = lookup_ops.index_table_from_file(
        tgt_vocab_file, default_value=UNK_ID)
  return src_vocab_table, tgt_vocab_table

使用了tensorflow库中定义的lookup_ops,简化了产生字典的操作

if not output_buffer_size:
    output_buffer_size = batch_size * 1000
  src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
  tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
  tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)

# 通过zip操作将源数据集和目标数据集合并在一起
# 此时的张量变化 [src_dataset] + [tgt_dataset] ---> [src_dataset, tgt_dataset]
  src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
# 数据集分片,分布式训练的时候可以分片来提高训练速度
  src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
  if skip_count is not None:
    src_tgt_dataset = src_tgt_dataset.skip(skip_count)
# 随机打乱数据,切断相邻数据之间的联系
# 根据文档,该步骤要尽早完成,完成该步骤之后在进行其他的数据集操作
  src_tgt_dataset = src_tgt_dataset.shuffle(
      output_buffer_size, random_seed, reshuffle_each_iteration)
    
  # 将每一行数据,根据“空格”切分开来
  # 这个步骤可以并发处理,用num_parallel_calls指定并发量
  # 通过prefetch来预获取一定数据到缓冲区,提升数据吞吐能力
  # 张量变化举例 ['上海 浦东', '上海 浦东'] ---> [['上海', '浦东'], ['上海', '浦东']]
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (
          tf.string_split([src]).values, tf.string_split([tgt]).values),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

  # Filter zero length input sequences.
  src_tgt_dataset = src_tgt_dataset.filter(
      lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))
# 限制源数据最大长度
  if src_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src[:src_max_len], tgt),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  # 限制目标数据的最大长度
  if tgt_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src, tgt[:tgt_max_len]),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  # Convert the word strings to ids.  Word strings that are not in the
  # vocab get the lookup table's default_value integer.
  # 通过map操作将字符串转换为数字
  # 张量变化举例 [['上海', '浦东'], ['上海', '浦东']] ---> [[1, 2], [1, 2]]
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
                        tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
  
  # 给目标数据加上 sos, eos 标记
  # 张量变化举例 [[1, 2], [1, 2]] ---> [[1, 2], [sos_id, 1, 2], [1, 2, eos_id]]
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt: (src,
                        tf.concat(([tgt_sos_id], tgt), 0),
                        tf.concat((tgt, [tgt_eos_id]), 0)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
  # Add in sequence lengths.
  # 增加长度信息
  # 张量变化举例 [[1, 2], [sos_id, 1, 2], [1, 2, eos_id]] ---> [[1, 2], [sos_id, 1, 2], [1, 2, eos_id], [src_size], [tgt_size]]
  src_tgt_dataset = src_tgt_dataset.map(
      lambda src, tgt_in, tgt_out: (
          src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
      num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

处理过程分析:

  1. 开始标记和结束标记,转换成为int32

  2. 关于增加sos以及eos标记,为啥src和target添加的标记不同?

  3. 关于增加长度信息的意义?

# 数据对齐
# 参数x实际上就是我们的 dataset 对象
def batching_func(x):
    # 调用dataset的padded_batch方法,对齐的同时,也对数据集进行分批
    return x.padded_batch(
        batch_size,
        # 对齐数据的形状
        padded_shapes=(
            # 因为数据长度不定,因此设置None
            tf.TensorShape([None]),  # src
            # 因为数据长度不定,因此设置None
            tf.TensorShape([None]),  # tgt_input
            # 因为数据长度不定,因此设置None
            tf.TensorShape([None]),  # tgt_output
            # 数据长度张量,实际上不需要对齐
            tf.TensorShape([]),  # src_len
            tf.TensorShape([])),  # tgt_len
        # 对齐数据的值
        padding_values=(
            # 用src_eos_id填充到 src 的末尾
            src_eos_id,  # src
            # 用tgt_eos_id填充到 tgt_input 的末尾
            tgt_eos_id,  # tgt_input
            # 用tgt_eos_id填充到 tgt_output 的末尾
            tgt_eos_id,  # tgt_output
            0,  # src_len -- unused
            0))  # tgt_len -- unused

这个数据对齐,没看懂。。

 if num_buckets > 1:

    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
      # Calculate bucket_width by maximum source sequence length.
      # Pairs with length [0, bucket_width) go to bucket 0, length
      # [bucket_width, 2 * bucket_width) go to bucket 1, etc.  Pairs with length
      # over ((num_bucket-1) * bucket_width) words all go into the last bucket.
      if src_max_len:
        bucket_width = (src_max_len + num_buckets - 1) // num_buckets
      else:
        bucket_width = 10

      # Bucket sentence pairs by the length of their source sentence and target
      # sentence.
      bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
      return tf.to_int64(tf.minimum(num_buckets, bucket_id))

    def reduce_func(unused_key, windowed_data):
      return batching_func(windowed_data)

    batched_dataset = src_tgt_dataset.apply(
        tf.contrib.data.group_by_window(
            key_func=key_func, reduce_func=reduce_func, window_size=batch_size))
  else:
    batched_dataset = batching_func(src_tgt_dataset)

关于分桶操作,得到的结果,就是相似长度的数据放在一起,能够提升计算效率!!

使用迭代器获取处理之后的数据

 batched_iter = batched_dataset.make_initializable_iterator()
  (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len,
   tgt_seq_len) = (batched_iter.get_next())

Edit page
Share this post on:

Previous Post
bolt-内嵌kv存储
Next Post
Improve Spell Correct