# nbi:hide_in
from music21 import *
from music21 import interval
import os
from collections import defaultdict
from pprint import *
import pickle
import numpy as np
import copy
import random
from ipywidgets import interact
# nbi:hide_in
def load_all_scores_from_folder(relative_path_name: str, limit: float = 1e9):
    """Parse every ``.mid`` file under a folder with music21's converter.

    Args:
        relative_path_name: folder to walk recursively, e.g. 'Datasets/chord_progs'.
        limit: stop once this many files have been loaded successfully
            (files that fail to parse do not count). ``limit <= 0`` loads nothing.

    Returns:
        dict mapping each file's path (``os.path.join(subdir, file)``) to the
        object returned by ``converter.parse``.
    """
    scores = {}
    count = 0
    for subdir, dirs, files in os.walk(relative_path_name):
        # Sorting `dirs` in place makes os.walk descend in a stable order, so the
        # files chosen under `limit` (and the dict's order) don't depend on the
        # filesystem's listing order.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.mid'):
                continue
            # Check before parsing so that limit <= 0 really loads nothing.
            if count >= limit:
                print('done, loaded', count, 'midi files')
                return scores
            path = os.path.join(subdir, file)
            try:
                score = converter.parse(path)
            except IndexError:
                print('failed to load a midi')
                continue
            scores[path] = score
            count += 1
    print('done, loaded', count, 'midi files')
    return scores
# nbi:hide_in
def filter_for_beats(beats_dict):
    """Keep only drum beats that have a C2 hit at their very start.

    An entry is dropped when its key contains 'fill' (drum fills are short
    and won't work well for our purposes). It is also dropped when its
    flattened stream has no C2 at offset 0, whether as a Note or inside a
    Chord. C2 is presumably the kick drum (MIDI 36); TODO confirm.

    Args:
        beats_dict: mapping of name/path -> music21 stream.

    Returns:
        A new dict with the surviving entries, in the original order. The
        input dict is not modified.
    """
    def _has_c2_at_start(stream_obj):
        # Compare MIDI numbers, not Note objects: music21 Note equality also
        # compares duration (and articulations, etc.). A short drum hit would
        # then never equal the default quarter-note note.Note('C2').
        target_midi = pitch.Pitch('C2').midi
        for x in stream_obj.flat.getElementsByOffset(0):
            if isinstance(x, note.Note) and x.pitch.midi == target_midi:
                return True
            if isinstance(x, chord.Chord) and any(p.midi == target_midi for p in x.pitches):
                return True
        return False

    # The 'fill' name check runs first, so fills are skipped without being inspected.
    return {k: v for k, v in beats_dict.items() if 'fill' not in k and _has_c2_at_start(v)}
# nbi:hide_in
def fix_scores(sc):
    """Put a score into measures and pad each short measure with a trailing rest.

    Calls ``makeMeasures()`` on ``sc``. Then, for every Measure found via
    ``semiFlat``, if its content is shorter than its ``barDuration``, a Rest
    of exactly the missing quarterLength is inserted at the measure's
    current end. Returns the stream produced by ``makeMeasures()``.
    """
    measured = sc.makeMeasures()

    for measure in measured.semiFlat.getElementsByClass("Measure"):
        filled = measure.duration.quarterLength
        gap = measure.barDuration.quarterLength - filled
        if gap <= 0:
            continue
        measure.insert(filled, note.Rest(quarterLength=gap))

    return measured
def fix_all_scores(dicto):
    """Return a new dict with ``fix_scores`` applied to every value (keys and order kept)."""
    return {key: fix_scores(score) for key, score in dicto.items()}
# nbi:hide_in
def _transpose_to_reference_key(score):
    """Transpose ``score`` so its analyzed key's tonic becomes C (major) or A (minor).

    Uses ``score.analyze('key')``. Returns the transposed score, or ``score``
    itself unchanged if the analyzed mode is neither 'major' nor 'minor'.
    """
    detected = score.analyze('key')
    if detected.mode == 'major':
        target = pitch.Pitch('C')
    elif detected.mode == 'minor':
        target = pitch.Pitch('A')
    else:
        # Keep the score as-is rather than reusing another score's result.
        return score
    return score.transpose(interval.Interval(detected.tonic, target))


def init_stuff(folder='widgets/chord_progs'):
    """Load the chord-progression MIDI files and normalize them.

    Every ``.mid`` file under ``folder`` is loaded and measure-padded with
    ``fix_all_scores``. Each score is then transposed so that major-key
    pieces are in C and minor-key pieces are in A.

    Args:
        folder: directory to load from (defaults to the original hard-coded path).

    Returns:
        dict mapping file path -> transposed score.
    """
    scores = fix_all_scores(load_all_scores_from_folder(folder))
    return {path: _transpose_to_reference_key(score) for path, score in scores.items()}
# Load, measure-pad and key-normalize the chord progressions once at module level;
# generate_chords_4bars_good reads this module-level dict.
chords_example = init_stuff()
# Notebook output from the line above: done, loaded 100 midi files
# nbi:hide_in
def generate_chords_4bars_good(random_seed):
    """Pick a chord progression from ``chords_example`` by seed and display it.

    The pick is reproducible for a given ``random_seed``.

    If the chosen score's ``len()`` is 2, a new Score is returned holding the
    score followed by a deep copy of it. Otherwise ``c[0:4]`` is displayed
    with ``show('text')`` and None is returned.
    """
    # Use a private generator: calling random.seed() here would reset the
    # global RNG for every other user of the `random` module. Random(seed)
    # starts from the same state, so randint draws the same index as before.
    rng = random.Random(random_seed)
    s = rng.randint(0, len(chords_example) - 1)

    c = chords_example[list(chords_example)[s]]
    if len(c) == 2:
        print('yelp')
        tst = stream.Score()
        tst.append(c)
        tst.append(copy.deepcopy(c))
        return tst
    # NOTE(review): slicing a Stream selects its first four top-level elements
    # (for a Score, likely parts/metadata), not necessarily four bars. Confirm intent.
    c = c[0:4]
    c.show('text')

# ipywidgets builds an integer slider for random_seed from the (0, 100) tuple and
# re-runs generate_chords_4bars_good on each change. The trailing ';' suppresses
# the notebook cell's echoed return value.
interact(generate_chords_4bars_good, random_seed=(0,100));
# nbi:hide_in