from best.data_iterators import iter_best_files, iter_best_old_files
from best.model import model_handler
from best.test_ids import test_ids, valid_ids
from best.custom_logging import logging
from best.belief_model import belief_model_handler
from best.results_to_file import write_results
import matplotlib
from matplotlib import pyplot

def target_specific_dist(doc, prefix, corpus_type):
    """Count pos/neg sentiments in *doc* whose target mention id starts with *prefix*.

    Sentiments with polarity 'none' are skipped. Besides the raw 'pos' /
    'neg' counts, a sentiment is also tallied under '<polarity>_auth' when:

    * corpus_type == 'nw' and str(source) == 'None', or
    * corpus_type == 'df', str(source) != 'None' and
      str(source.nom_head) == 'None'.

    ('_auth' presumably marks sentiments attributed to the document author;
    confirm against the annotation guidelines.)

    Args:
        doc: document exposing ``doc.evaluator_best.sentiments``.
        prefix: mention-id prefix selecting the targets, e.g. 're', 'em',
            'm-'. Any length is accepted.
        corpus_type: 'nw' or 'df'; any other value never counts '_auth'.

    Returns:
        dict with keys 'pos_auth', 'neg_auth', 'pos', 'neg'.

    Raises:
        ValueError: if a selected sentiment's polarity is not 'pos', 'neg'
            or 'none'. The check happens before any counter is touched.
    """
    stats = {'pos_auth': 0, 'neg_auth': 0, 'pos': 0, 'neg': 0}
    for sentiment in doc.evaluator_best.sentiments:
        mention_id = sentiment.target.mention_id
        polarity = sentiment.polarity
        # Slice-compare so prefixes of any length work.
        if mention_id[:len(prefix)] != prefix or polarity == 'none':
            continue
        if polarity not in ('pos', 'neg'):
            raise ValueError(
                "unexpected sentiment polarity {!r} for target {!r}".format(
                    polarity, mention_id))
        stats[polarity] += 1

        src = sentiment.source
        # Missing values are detected through their string form, matching
        # how the rest of this module compares against 'None'.
        if corpus_type == 'nw':
            is_auth = str(src) == 'None'
        elif corpus_type == 'df':
            is_auth = str(src) != 'None' and str(src.nom_head) == 'None'
        else:
            is_auth = False
        if is_auth:
            stats[polarity + '_auth'] += 1
    return stats

def sentiment_distribution(docs, corpus_type):
    """Aggregate, print and return sentiment counts by target type for one corpus.

    For every document, target_specific_dist() counts sentiments whose target
    mention id starts with 're' (relations), 'em' (events) or 'm-'
    (entities). The per-document counts are summed and printed, followed by
    pair totals and per-document averages.

    Args:
        docs: sized sequence of documents (len() is taken). Each document must
            expose sized ``pairs_relation`` and ``pairs_event`` collections.
        corpus_type: corpus label, e.g. 'df' or 'nw'. It is forwarded to
            target_specific_dist() and used in the printed headings.

    Returns:
        [event_stats, relation_stats, entity_stats, no_ent_stats, total_stats].
        Each element is a dict keyed 'pos_auth', 'neg_auth', 'pos', 'neg'.
        no_ent_stats is events + relations; total_stats also adds entities.
        An empty *docs* yields all-zero dicts, and its averages print as nan
        instead of raising ZeroDivisionError.
    """
    keys = ('pos_auth', 'neg_auth', 'pos', 'neg')
    event_stats = dict.fromkeys(keys, 0)
    relation_stats = dict.fromkeys(keys, 0)
    entity_stats = dict.fromkeys(keys, 0)

    for doc in docs:
        relation_data = target_specific_dist(doc, 're', corpus_type)
        event_data = target_specific_dist(doc, 'em', corpus_type)
        entity_data = target_specific_dist(doc, 'm-', corpus_type)
        for key in keys:
            event_stats[key] += event_data[key]
            relation_stats[key] += relation_data[key]
            entity_stats[key] += entity_data[key]

    # The derived totals only need computing once, after all documents.
    no_ent_stats = {key: event_stats[key] + relation_stats[key] for key in keys}
    total_stats = {key: no_ent_stats[key] + entity_stats[key] for key in keys}

    print("AGGREGATE for {}:".format(corpus_type))
    print("Event Stats: {}".format(event_stats))
    print("Relation Stats: {}".format(relation_stats))
    print("Entity Stats: {}".format(entity_stats))
    print("Total Stats without entities: {}".format(no_ent_stats))
    print("Total Stats: {}".format(total_stats))

    rel_total = sum(len(doc.pairs_relation) for doc in docs)
    # Both polarities are counted; the printed label calls these "Positive"
    # (presumably meaning positive examples, i.e. carrying a sentiment).
    rel_pos = relation_stats['pos'] + relation_stats['neg']
    event_total = sum(len(doc.pairs_event) for doc in docs)
    event_pos = event_stats['pos'] + event_stats['neg']

    n = len(docs)
    print("Seperator ***************************************************")
    print("Type: {} Number of Docs: {} Number of Relation Pairs {} Number of Positive Relations {} Number of Event Pairs {} Number of Positive Events {}".format(corpus_type, n, rel_total, rel_pos, event_total, event_pos))
    if n:
        averages = (rel_total / n, rel_pos / n, event_total / n, event_pos / n)
    else:
        # An empty corpus has no meaningful per-document average.
        averages = (float('nan'),) * 4
    print("Average stats for Type: {} Number of Docs: {} Number of Relation Pairs {} Number of Positive Relations {} Number of Event Pairs {} Number of Positive Events {}".format(corpus_type, n, *averages))
    print("Seperator ***************************************************")
    return [event_stats, relation_stats, entity_stats, no_ent_stats, total_stats]

def avg_word_count(docs):
    """Count tokens across *docs* and the mean tokens per document.

    Tokens are read from ``doc.tokenized['sentences'][i]['tokens']``.

    Args:
        docs: sized sequence of documents (len() is taken).

    Returns:
        (total_tokens, number_of_docs, tokens_per_doc). tokens_per_doc is
        NaN for an empty *docs* instead of raising ZeroDivisionError.
    """
    count = sum(
        len(sent['tokens'])
        for doc in docs
        for sent in doc.tokenized['sentences']
    )
    n_docs = len(docs)
    average = count / n_docs if n_docs else float('nan')
    return count, n_docs, average


def partition(docs):
    """Split documents into newswire and discussion-forum lists.

    A document is treated as discussion forum when characters 4-5 of its
    ``doc_id`` are 'DF'; every other document goes to the newswire list.
    Input order is preserved within each list.

    Returns:
        (nw, df, docs), with *docs* returned unchanged.
    """
    newswire = []
    forum = []
    for document in docs:
        target = forum if document.doc_id[4:6] == 'DF' else newswire
        target.append(document)
    return newswire, forum, docs

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float('nan')


def _annotations_per_mention(docs, annotations_of, is_counted):
    """Count how many selected annotations target each ERE mention.

    Every mention in ``doc.evaluator_ere.entity_mentions``,
    ``relation_mentions`` and ``event_mentions`` is seeded with 0, so
    mentions that no annotation targets are still counted.
    An annotation from ``annotations_of(doc)`` contributes only when
    ``is_counted(annotation)`` is true. It is routed by the first character
    of its target ``mention_id``:

    * 'm' -> entity
    * 'r' -> relation
    * 'e' -> event
    * anything else is ignored

    Returns:
        (entity_counts, relation_counts, event_counts), each a dict keyed by
        ``"<mention> <doc_id>"``.

    Raises:
        KeyError: if a counted, routed target is not among the seeded keys.
    """
    entity_counts, relation_counts, event_counts = {}, {}, {}
    by_first_char = {'m': entity_counts, 'r': relation_counts, 'e': event_counts}
    for doc in docs:
        doc_id = str(doc.doc_id)
        ere = doc.evaluator_ere
        for seeded, mentions in ((entity_counts, ere.entity_mentions),
                                 (relation_counts, ere.relation_mentions),
                                 (event_counts, ere.event_mentions)):
            seeded.update({str(mention) + " " + doc_id: 0 for mention in mentions})
        for annotation in annotations_of(doc):
            if not is_counted(annotation):
                continue
            target = str(annotation.target.mention_id) + " " + doc_id
            bucket = by_first_char.get(target[0])
            if bucket is not None:
                bucket[target] += 1
    return entity_counts, relation_counts, event_counts


def _print_mention_histograms(label, entity_counts, relation_counts, event_counts):
    """Print "annotations per mention -> number of mentions" histograms.

    One combined histogram is printed first, then one per mention type.

    Returns:
        (annotations, mentions): the annotation total implied by the combined
        histogram, and the total number of mentions.
    """
    hits, hits_m, hits_r, hits_e = {}, {}, {}, {}
    for counts, per_type in ((entity_counts, hits_m),
                             (relation_counts, hits_r),
                             (event_counts, hits_e)):
        for n_targeting in counts.values():
            hits[n_targeting] = hits.get(n_targeting, 0) + 1
            per_type[n_targeting] = per_type.get(n_targeting, 0) + 1
    print(label + " per mention: mentions", hits)
    print("Same thing for entities", hits_m)
    print("Same thing for relations", hits_r)
    print("Same thing for events", hits_e)
    mentions = sum(hits.values())
    annotations = sum(value * count for value, count in hits.items())
    return annotations, mentions


def _print_per_mention_statistics(docs):
    """Print belief and sentiment per-mention histograms and ratios for *docs*.

    Beliefs with belief_type 'na' and sentiments with polarity 'none' are
    excluded. Ratios with a zero denominator print as nan.
    """
    counts = _annotations_per_mention(
        docs,
        lambda doc: doc.evaluator_best.beliefs,
        lambda belief: str(belief.belief_type) != 'na',
    )
    b, m = _print_mention_histograms("Beliefs", *counts)
    print("Beliefs: {0}  Mentions: {1} Beliefs/Mentions: {2:.3f}".format(b, m, _safe_ratio(b, m)))
    non_entity = m - len(counts[0])
    print("Beliefs: {0} (Non-entity) Mentions: {1} Beliefs/(Non-entity)Mentions: {2:.3f}".format(b, non_entity, _safe_ratio(b, non_entity)))

    counts = _annotations_per_mention(
        docs,
        lambda doc: doc.evaluator_best.sentiments,
        lambda sentiment: str(sentiment.polarity) != "none",
    )
    s, m = _print_mention_histograms("Sentiments", *counts)
    print("Sentiments: {0} mentions: {1} Sentiments/Mentions: {2:.3f}".format(s, m, _safe_ratio(s, m)))


def mention_statistics(docs, per_mention=False):
    """Print corpus-wide sentiment statistics, then terminate the interpreter.

    Splits *docs* into newswire and discussion-forum documents with
    partition() and prints sentiment_distribution() for each corpus.
    It then prints:

    * the key-wise sum of both corpora's stats;
    * (relation + event pos/neg sentiments) / (event + relation pairs);
    * token counts per corpus from avg_word_count();
    * (entity pos/neg sentiments) / entity pairs.

    Ratios whose denominator is 0 print as ``nan``.

    Args:
        docs: iterable of documents. It is materialized into a list here
            because it is traversed several times.
        per_mention: when true, also print histograms of how many beliefs
            and sentiments target each ERE mention before exiting. Defaults
            to False, which stops right after the summary above.

    Raises:
        SystemExit: always, once printing is finished.
    """
    import sys

    docs = list(docs)
    nw, df, docs = partition(docs)
    df_stats = sentiment_distribution(df, 'df')
    nw_stats = sentiment_distribution(nw, 'nw')
    # Both lists are [event, relation, entity, no_ent, total]; sum key-wise.
    event_stats, relation_stats, entity_stats, no_ent_stats, total_stats = [
        {key: df_stat[key] + nw_stat[key] for key in df_stat}
        for df_stat, nw_stat in zip(df_stats, nw_stats)
    ]
    print("AGGREGATE for {}:".format('all documents'))
    print("Event Stats: {}".format(event_stats))
    print("Relation Stats: {}".format(relation_stats))
    print("Entity Stats: {}".format(entity_stats))
    print("Total Stats without entities: {}".format(no_ent_stats))
    print("Total Stats: {}".format(total_stats))
    print("Seperator ***************************************************")

    numerator = (relation_stats['pos'] + relation_stats['neg']
                 + event_stats['pos'] + event_stats['neg'])
    denominator = sum(len(doc.pairs_event) + len(doc.pairs_relation) for doc in docs)
    print("Total percent positive examples: {} numerator: {} denominator: {}".format(_safe_ratio(numerator, denominator), numerator, denominator))

    df_wc, df_len, df_avg_wc = avg_word_count(df)
    print("For Discussion Forums, word_count: {} #docs: {}, avg_word_count: {}".format(df_wc, df_len, df_avg_wc))
    nw_wc, nw_len, nw_avg_wc = avg_word_count(nw)
    print("For Newswire, word_count: {} #docs: {}, avg_word_count: {}".format(nw_wc, nw_len, nw_avg_wc))
    print("Seperator ***************************************************")

    pos_ent_examples = entity_stats['pos'] + entity_stats['neg']
    ent_examples = sum(len(doc.pairs_entity) for doc in docs)
    print("Entity information corpus wide - Pos Examples: {} Examples: {} Percent {}".format(pos_ent_examples, ent_examples, _safe_ratio(pos_ent_examples, ent_examples)))

    if per_mention:
        _print_per_mention_statistics(docs)
    # This report ends the run. sys.exit() is used because the exit()
    # builtin is injected by the site module for interactive use only.
    sys.exit()
