dnc-with-demon/Utils/Collate.py

75 lines
2.6 KiB
Python
Raw Permalink Normal View History

2022-11-06 05:59:40 +08:00
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import numpy as np
import torch
from operator import mul
from functools import reduce
class VarLengthCollate:
def __init__(self, ignore_symbol=0):
self.ignore_symbol = ignore_symbol
def _measure_array_max_dim(self, batch):
s=list(batch[0].size())
different=[False] * len(s)
for i in range(1, len(batch)):
ns = batch[i].size()
different = [different[j] or s[j]!=ns[j] for j in range(len(s))]
s=[max(s[j], ns[j]) for j in range(len(s))]
return s, different
def _merge_var_len_array(self, batch):
max_size, different = self._measure_array_max_dim(batch)
s=[len(batch)] + max_size
storage = batch[0].storage()._new_shared(reduce(mul, s, 1))
out = batch[0].new(storage).view(s).fill_(self.ignore_symbol)
for i, d in enumerate(batch):
this_o = out[i]
for j, diff in enumerate(different):
if different[j]:
this_o = this_o.narrow(j,0,d.size(j))
this_o.copy_(d)
return out
def __call__(self, batch):
if isinstance(batch[0], dict):
return {k: self([b[k] for b in batch]) for k in batch[0].keys()}
elif isinstance(batch[0], np.ndarray):
return self([torch.from_numpy(a) for a in batch])
elif torch.is_tensor(batch[0]):
return self._merge_var_len_array(batch)
else:
assert False, "Unknown type: %s" % type(batch[0])
class MetaCollate:
def __init__(self, meta_name="meta", collate=VarLengthCollate()):
self.meta_name = meta_name
self.collate = collate
def __call__(self, batch):
if isinstance(batch[0], dict):
meta = [b[self.meta_name] for b in batch]
batch = [{k: v for k,v in b.items() if k!=self.meta_name} for b in batch]
else:
meta = None
res = self.collate(batch)
if meta is not None:
res[self.meta_name] = meta
return res