# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================

import numpy as np
import torch
from operator import mul
from functools import reduce


class VarLengthCollate:
    def __init__(self, ignore_symbol=0):
        self.ignore_symbol = ignore_symbol

    def _measure_array_max_dim(self, batch):
        # Elementwise maximum of all tensor shapes in the batch, plus one flag
        # per dimension telling whether any two tensors disagree in that size.
        s = list(batch[0].size())
        different = [False] * len(s)
        for i in range(1, len(batch)):
            ns = batch[i].size()
            different = [different[j] or s[j] != ns[j] for j in range(len(s))]
            s = [max(s[j], ns[j]) for j in range(len(s))]
        return s, different

    def _merge_var_len_array(self, batch):
        max_size, different = self._measure_array_max_dim(batch)
        s = [len(batch)] + max_size
        # Allocate a single shared-memory tensor big enough for the padded
        # batch and pre-fill it with the padding symbol.
        storage = batch[0].storage()._new_shared(reduce(mul, s, 1))
        out = batch[0].new(storage).view(s).fill_(self.ignore_symbol)
        for i, d in enumerate(batch):
            this_o = out[i]
            # Narrow the output slice along every dimension whose size varies
            # across the batch, so the copy targets exactly this sample's extent.
            for j, diff in enumerate(different):
                if diff:
                    this_o = this_o.narrow(j, 0, d.size(j))
            this_o.copy_(d)
        return out

    def __call__(self, batch):
        # Recurse into dicts and convert numpy arrays to tensors, so that only
        # the tensor case has to deal with padding.
        if isinstance(batch[0], dict):
            return {k: self([b[k] for b in batch]) for k in batch[0].keys()}
        elif isinstance(batch[0], np.ndarray):
            return self([torch.from_numpy(a) for a in batch])
        elif torch.is_tensor(batch[0]):
            return self._merge_var_len_array(batch)
        else:
            assert False, "Unknown type: %s" % type(batch[0])
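

# Usage sketch (hypothetical; assumes `dataset` yields dicts of
# variable-length numpy arrays or tensors):
#
#     loader = torch.utils.data.DataLoader(dataset, batch_size=16,
#                                          collate_fn=VarLengthCollate())
#
# Samples in a batch may then differ in size along any dimension; shorter ones
# are padded with `ignore_symbol`.
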
class MetaCollate:
    def __init__(self, meta_name="meta", collate=VarLengthCollate()):
        self.meta_name = meta_name
        self.collate = collate

    def __call__(self, batch):
        # Pull the metadata out of each sample before collating, then attach
        # it to the result as a plain list (it is neither padded nor stacked).
        if isinstance(batch[0], dict):
            meta = [b[self.meta_name] for b in batch]
            batch = [{k: v for k, v in b.items() if k != self.meta_name} for b in batch]
        else:
            meta = None

        res = self.collate(batch)
        if meta is not None:
            res[self.meta_name] = meta

        return res
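

if __name__ == "__main__":
    # Hypothetical smoke test: two samples of different lengths get padded to
    # a common shape, while the "meta" field bypasses collation and comes back
    # as a plain Python list.
    samples = [
        {"input": np.arange(3, dtype=np.int64), "meta": {"len": 3}},
        {"input": np.arange(5, dtype=np.int64), "meta": {"len": 5}},
    ]
    collated = MetaCollate()(samples)
    print(collated["input"])  # LongTensor of shape (2, 5); row 0 is zero-padded
    print(collated["meta"])   # [{'len': 3}, {'len': 5}]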