diff --git a/.gitignore b/.gitignore index 1a72571c3d478ddf1eb4122fb731f6c722da5baa..82a5e7396391268d4e9577a1c7fa2318e0efcc73 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,8 @@ pairing/ .cache .coverage .vscode +<<<<<<< HEAD data/ +======= +.idea/ +>>>>>>> two_way_derivatives diff --git a/database/database_table.py b/database/database_table.py index 8025b136f5fa834e0f89b90d3a90a2a6fd213665..b785db32c1ec698bebf9aa7899bb17de0068cab9 100644 --- a/database/database_table.py +++ b/database/database_table.py @@ -149,6 +149,8 @@ class DatabaseTable(Table): self._protocol = None if not hasattr(self, '_definitions'): self._definitions = None + if not hasattr(self, '_derivatives'): + self._derivatives = {} if 'protocol' in kwargs.keys(): self.load_protocol(kwargs['protocol']) @@ -781,69 +783,65 @@ class DatabaseTable(Table): have been resolved and updated, to ensure dependencies will not be ignored. ''' if self._protocol is None: - return {'original': original, 'dbcolumn': original, 'new': original, 'level': 0} + return {'query': original, 'dbcolumn': original, 'level': 0} target = self._get_variable_target(original, year) - if target in self._derivatives: # This variable has been evaluated already, just return return self._derivatives[target] if target is not None and str(self) + '.' + target in recursion_list: # This is a circular reference. Don't be like that. - print(target) + print(target, self) raise CircularReferenceError(target) - original = self._protocol.original_from_target(target, year) or original - try: - dbcolumn = self._protocol.dbcolumn_from_target(target) - except InvalidTargetError: - dbcolumn = None + # Query for the column, header of csv or var name if empty + case = self._protocol.original_from_target(target, year) or original - if is_aggregation(original): + if not case.startswith('~'): + # Possible header, not a query to be evaluated here + return {'query': case, 'dbcolumn': original, 'level': 0} + + # original column array [var_name, type] + dbcolumn = self._protocol.dbcolumn_from_target(target) + + if is_aggregation(case): # Aggregation not integrated - derivative = {'original': original, 'dbcolumn': dbcolumn, 'new': original, 'level': -1} + derivative = {'query': case, 'dbcolumn': dbcolumn, 'level': -1} self._derivatives[target] = derivative return derivative - denorm_match = re.match(r'~?([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)', original) - if denorm_match is not None: - table, column = denorm_match.groups() - table = gen_data_table(table, self.metadata) + # Well, looks like we actually got a derivative or denormalization here + recursion_list.append(str(self) + '.' + target) + level = 1 # level of dependency of this var + referred_tables = [] + + case = case.strip("~") # doesn't need "~" anymore + str_list = re.findall(r'("[\w]"|[\w.]+)', case) + for substr in str_list: + if '.' in substr: # We have a var from another table + table = substr.split('.')[0] + table = gen_data_table(table, self.metadata) + var_name = substr.split('.')[1] + level += 1 + else: + table = self + var_name = substr - if table is self: - return self._derivative_recursion(column, year, recursion_list) - derivative = table._resolv_derivative(column, year) + # If it is a var, will need to be evaluated as it's a dependency. + if table._protocol and table._protocol.target_from_dbcolumn(var_name) is not None: + # Prevents lowering level if var is an aggregation. + level = max(level + table._derivative_recursion(var_name, year, recursion_list)['level'], level) - self._derivatives[target] = {'original': original, 'dbcolumn': dbcolumn, 'level': 0, 'dbmapped': True, - 'new': '.'.join([table.name, derivative['dbcolumn'][0]])} - return self._derivatives[target] + var_target = self._get_variable_target(var_name.strip('"'), year) + if var_target is not None: + var_db = self._protocol.dbcolumn_from_target(var_target)[0] + case = case.replace(var_name, var_db) - if not original.startswith('~'): - # Possibly keyword, definitely not a variable. Shouldn't change the level. - return {'original': original, 'processed': original, 'dbcolumn': dbcolumn, 'level': 0} + if table is not self and table not in referred_tables: + referred_tables.append(table) - # Well, looks like we actually got a derivative here - original = original.strip('~ ') - str_list = re.findall(r'("[\w]+"|[\w]+)', original) - level = 0 - substitutions = [] - recursion_list.append(str(self) + '.' + target) - for substring in str_list: - derivative = self._derivative_recursion(substring.strip('"'), year, - recursion_list=recursion_list) - if derivative['dbcolumn']: - substitutions.append({'original': substring, 'new': derivative['dbcolumn'][0]}) - if derivative['level'] >= level: - level = derivative['level'] + 1 - - processed = original - dbmapped = False # column neded to execute the derivative is present on table or need a file. - for substitution in substitutions: - processed = re.sub(substitution['original'], substitution['new'], processed) - dbmapped = True - self._derivatives[target] = {'original': original, 'dbcolumn': dbcolumn, 'level': level, - 'processed': processed, 'dbmapped': dbmapped} + self._derivatives[target] = {'query': case, 'dbcolumn': dbcolumn, 'level': level, 'tables': referred_tables} return self._derivatives[target] def _resolv_derivative(self, original, year): @@ -855,30 +853,40 @@ class DatabaseTable(Table): self._derivatives = {} return self._derivative_recursion(original, year) - def _get_denormalizations(self, ttable, originals, year): - ''' - Searches protocol for denormalizations and yields the necessary update queries. - ''' - exp = r'([a-zA-Z0-9_]+)\.([a-zA-Z0-9_]+)' - external = {} - for dst, original in originals: - original = original.strip(' ~\n\t') - for match in re.finditer(exp, original): - table, column = match.groups() - if table not in external: - external[table] = [] - external[table].append([dst, text(original)]) - - for table in external: - query = update(ttable) - for dst, src in external[table]: - query = query.values(**{dst[0]: src}) - for fk_column, fkey in self.get_relations(table): - fk_column = ttable.columns.get(fk_column.name) + def _apply_denormalization(self, ttable, year, column, denorm_query, referred_tables, bind): + + # Hack to make pymonetdb be able to work with the columns from 2 tables when one is temporary + t_schema = ttable.schema + ttable.schema = None + + query = update(ttable) + + query = query.values(**{column: text(denorm_query)}) + + for ref_table in referred_tables: + ref_table.map_from_database(bind) + fk_tuples = [(ttable.columns.get(fk_column.name), fkey) + for fk_column, fkey in self.get_relations(ref_table)] + if not fk_tuples: + logger.warning("Trying to use relations from " + str(ref_table) + + " instead of " + str(self) + " to apply derivative.") + fk_tuples = [(ttable.columns.get(fkey.name),fk_column) + for fk_column, fkey in ref_table.get_relations(self)] + if not fk_tuples: + logger.error("COULDN'T ESTABLISH " + ref_table.name + " RELATION WITH " + self.name + + " IGNORING COLUMN: " + column) + ttable.schema = t_schema + return + + for fk_column, fkey in fk_tuples: query = query.where(fk_column == fkey) - if year: - query = query.where(ttable.columns.get(settings.YEAR_COLUMN) == year) - yield query + + if year: + query = query.where(ttable.columns.get(settings.YEAR_COLUMN) == year) + + bind.execute(query) + + ttable.schema = t_schema def apply_derivatives(self, ttable, columns, year, bind=None, dbonly=False): ''' @@ -889,32 +897,25 @@ class DatabaseTable(Table): if bind is None: bind = self.metadata.bind - self._derivatives = {} for original in columns: self._resolv_derivative(original, year) - originals = [(self._derivatives[d]['dbcolumn'], self._derivatives[d]['original'])\ - for d in self._derivatives if self._derivatives[d]['level'] == 0] + max_level = max([self._derivatives[d]['level'] for d in self._derivatives]) - t_schema = ttable.schema - ttable.schema = None - for query in self._get_denormalizations(ttable, originals, year): - bind.execute(query) + for i in range(1, max_level + 1): + query = {} - ttable.schema = t_schema - if len(self._derivatives) > 0: - max_level = max([self._derivatives[d]['level'] for d in self._derivatives]) - for i in range(max_level): - i = i+1 - query = {} - level = [self._derivatives[d] for d in self._derivatives if\ - self._derivatives[d]['level'] == i] - for derivative in level: - if not dbonly or derivative['dbmapped']: - query[derivative['dbcolumn'][0]] = text(derivative['processed']) + level = [self._derivatives[d] for d in self._derivatives if self._derivatives[d]['level'] == i] + for derivative in level: + if len(derivative['tables']) == 0: + if not dbonly: + query[derivative['dbcolumn'][0]] = text(derivative['query']) + else: + self._apply_denormalization(ttable, year, derivative['dbcolumn'][0], derivative['query'], + derivative['tables'], bind) + if query: query = update(ttable).values(**query) - bind.execute(query) return self._derivatives @@ -995,7 +996,8 @@ class DatabaseTable(Table): foreign_key = fk break if not foreign_key: - raise MissingForeignKeyError(table) + logger.warning("Couldn't find foreign key relation between " + self.name + " and " + table.name) + return None for _, fk_column in foreign_key.columns.items(): fkey = list(fk_column.foreign_keys)[0] fkey = fkey.column.name