import base64
import json
import math
import os
import tempfile

try:
    import bpy
    import mathutils
    from . import scene
    from .. import utils
    from .. import config
    from ..tools import bouncecurve as bc
    from ..bouncecurve import bcutils
except ModuleNotFoundError as e:
    import scene
    import utils
    import config
    import bcutils
    from tools import bouncecurve as bc


def get_strokes():
    """Return the point positions of every stroke on the active grease-pencil object.

    Only the first frame of the first layer is read. Each stroke is returned as
    a list of ``(x, y, z)`` tuples.

    An empty list is returned when:
    - there is no active object;
    - the object is not the grease-pencil type used by this Blender version
      ('GREASEPENCIL' on 4.3+, 'GPENCIL' before);
    - the object has no layer or no frame.

    The object is switched into edit mode before reading. It is always put back
    into OBJECT mode afterwards, including on the early "no layer / no frame"
    returns, which previously left it in edit mode.
    """
    obj = bpy.context.active_object
    if not obj:
        return []
    is_gp_v3 = bpy.app.version >= (4, 3, 0)
    if obj.type != ('GREASEPENCIL' if is_gp_v3 else 'GPENCIL'):
        return []

    # Enter edit mode; the operator that does this differs per Blender version.
    # NOTE(review): why edit mode is needed before reading is not visible here
    # (presumably to make stroke data up to date) -- confirm before removing.
    if bpy.app.version < (4, 2, 0):
        bpy.ops.object.mode_set(mode='EDIT_GPENCIL')
    elif is_gp_v3:
        bpy.ops.object.mode_set(mode='OBJECT')
        bpy.ops.object.editmode_toggle()
    else:
        bpy.ops.object.mode_set(mode='OBJECT')
        bpy.ops.gpencil.editmode_toggle()

    try:
        layers = obj.data.layers
        if len(layers) < 1:
            return []
        frames = layers[0].frames
        if len(frames) < 1:
            return []
        frame = frames[0]
        if is_gp_v3:
            # 4.3+ API: strokes live on the frame's drawing, points expose ``position``.
            return [
                [(p.position[0], p.position[1], p.position[2]) for p in stroke.points]
                for stroke in frame.drawing.strokes
            ]
        # Legacy API: strokes live on the frame, points expose ``co``.
        return [
            [(p.co[0], p.co[1], p.co[2]) for p in stroke.points]
            for stroke in frame.strokes
        ]
    finally:
        bpy.ops.object.mode_set(mode='OBJECT')


def get_score_staff_lines(col_name):
    """Return the end points of the first five curve objects in a collection.

    Args:
        col_name: name of a collection in ``bpy.data.collections``.

    Returns:
        A list with one ``(start_xy, end_xy)`` pair per object, taken from the
        first and second bezier point of the object's first spline. Presumably
        these five objects are the staff lines of the imported score (TODO
        confirm with the SVG importer's object order).
    """
    first_five = bpy.data.collections[col_name].objects[:5]
    endpoints = []
    for line_obj in first_five:
        bezier = line_obj.data.splines[0].bezier_points
        endpoints.append((bezier[0].co[:2], bezier[1].co[:2]))
    return endpoints


def select_all_objects(col):
    """Select every object of ``col`` and of all its nested child collections.

    Child collections are visited (depth first) before the collection's own
    objects.
    """
    def _objects_in(collection):
        # Yield nested collections' objects first, then this collection's own.
        for child in collection.children:
            yield from _objects_in(child)
        yield from collection.objects

    for obj in _objects_in(col):
        obj.select_set(True)


def _first_spline_points(obj):
    """Return the bezier points of ``obj``'s first spline, or ``()`` if it has none."""
    splines = getattr(obj.data, 'splines', None)
    if not splines:
        return ()
    return splines[0].bezier_points


def group_objects(context, data, col):
    """Sort the curve objects of an imported SVG score into sub-collections of ``col``.

    Objects are moved (with ``utils.move_object``) into:

    * ``音符{n}``: 4-point curves lying near a note or tie position of part ``n``;
    * ``谱线/横线``: 2-point curves with both points at local y ~= 0
      (horizontal lines); their world y is remembered as a staff-line height;
    * ``谱线/竖线``: other 2-point curves whose both ends land on a staff height;
    * ``其他/竖线``: remaining 2-point curves with a non-flat first point;
    * ``其他``: everything still left directly in ``col``.

    Args:
        context: Blender context; X/Z of ``context.scene.bc_obj_svg_score.scale``
            are applied to the score-space positions.
        data: score data with ``scale``, ``offset`` and ``parts`` (each part has
            ``note_paths`` and optionally ``tie_locations``). Not modified.
        col: collection holding the freshly imported curve objects.
    """
    scale = mathutils.Vector(data['scale'])
    offset = mathutils.Vector(data['offset'])
    obj_scales = context.scene.bc_obj_svg_score.scale
    score_scale = mathutils.Vector((obj_scales[0], obj_scales[2]))
    col_others = utils.get_or_new_collection('其他', col)
    col_others_col = utils.get_or_new_collection('竖线', col_others)

    # group notes
    for part_ind, part in enumerate(data['parts']):
        col_part = utils.get_or_new_collection(f'音符{part_ind + 1}', col)
        # Build a new list rather than appending to part['note_paths'], so the
        # caller's data is left untouched. All tie positions form one extra path.
        note_paths = list(part['note_paths'])
        note_paths.append([loc for locs in part.get('tie_locations', []) for loc in locs])
        for note_path in note_paths:
            # ``None`` marks a path slot without a note; it has no position to match.
            real_note_path = [
                (mathutils.Vector(point) * scale + offset) * score_scale
                for point in note_path
                if point is not None
            ]
            # Iterate a snapshot: matched objects are moved out of ``col`` in the loop.
            for obj in list(col.objects):
                if not real_note_path:
                    break
                if len(_first_spline_points(obj)) != 4:
                    continue
                for ind_real_point, real_point in enumerate(real_note_path):
                    if (abs(obj.location[0] - real_point[0]) <= 0.03
                            and abs(obj.location[1] - real_point[1]) <= 0.015):
                        utils.move_object(col_part, obj)
                        # Each note position claims at most one object.
                        del real_note_path[ind_real_point]
                        break

    # group staff lines
    col_staff = utils.get_or_new_collection('谱线', col)
    col_staff_row = utils.get_or_new_collection('横线', col_staff)
    col_staff_col = utils.get_or_new_collection('竖线', col_staff)
    staff_ys = set()
    for obj in list(col.objects):
        points = _first_spline_points(obj)
        if len(points) != 2:
            continue
        if abs(points[0].co[0]) < 0.001:
            continue
        if abs(points[0].co[1]) < 0.002 and abs(points[1].co[1]) < 0.002:
            staff_ys.add(obj.location[1])
            utils.move_object(col_staff_row, obj)
    for obj in list(col.objects):
        points = _first_spline_points(obj)
        if len(points) != 2:
            continue
        if abs(points[0].co[1]) < 0.005:
            continue
        loc_y = obj.location[1]
        # Count end points whose world y lies on some staff line.
        ends_on_staff = sum(
            1 for point in points
            if any(abs(loc_y + point.co[1] * obj.scale[1] - staff_y) < 0.02
                   for staff_y in staff_ys)
        )
        if ends_on_staff == 2:
            utils.move_object(col_staff_col, obj)
        else:
            utils.move_object(col_others_col, obj)

    # group others
    for obj in list(col.objects):
        utils.move_object(col_others, obj)


class GenerateScoreOperator(bpy.types.Operator):
    # Generate a score from the chosen .mid file:
    # - the file is analysed through ``utils.analyze_mid``;
    # - the SVG in the result is imported (as grease pencil or as curves);
    # - the analysis plus the score scale/offset is stored as JSON in
    #   ``scene.bc_score_data``.
    # (No class docstring on purpose: Blender would show it as the tooltip.)
    bl_idname = "object.bc_gen_score"
    bl_label = "生成乐谱"
    bl_options = {'REGISTER', 'UNDO'}

    def execute(self, context):
        """Analyse the MIDI file, import the score SVG and store the score data.

        Every failure is reported through ``self.report`` and ends the operator.
        """
        midi_path = context.scene.bc_midi
        # check if it's midi file
        if len(midi_path) < 1 or not midi_path.endswith('.mid'):
            self.report({'ERROR'}, '请选择mid文件')
            return {'FINISHED'}
        if not os.path.isfile(midi_path):
            print(midi_path)
            self.report({'ERROR'}, 'mid文件不存在')
            return {'FINISHED'}
        key = context.scene.mc_auth_key.strip()
        secret = context.scene.mc_auth_secret.strip()
        result, msg, data, meta = utils.analyze_mid(key, secret, midi_path)
        if not result:
            self.report({'ERROR'}, '生成失败' if msg is None else msg)
            return {'FINISHED'}
        staff_lines = data.get('staff_lines', [])
        if len(staff_lines) < 2:
            self.report({'INFO'}, '无法确定乐谱坐标')
            return {'FINISHED'}
        for part in data.get('parts', []):
            if len(part.get('times', [])) < 1:
                self.report({'INFO'}, '乐谱数据异常1')
        svg = data.get('svg')
        if not svg:
            # Nothing to import (and base64.b64decode(None) would raise TypeError).
            self.report({'ERROR'}, '生成失败')
            return {'FINISHED'}

        as_gpencil = context.scene.bc_import_svg_as_gspencil
        col = None
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.basename(midi_path.replace('.mid', '.svg'))
            dest_file = os.path.join(tmpdir, filename)
            with open(dest_file, 'wb') as f:
                f.write(base64.b64decode(svg))

            if as_gpencil:
                obj_score = self._import_svg_as_grease_pencil(dest_file)
                if obj_score is None:
                    return {'FINISHED'}
                context.scene.bc_obj_svg_score = obj_score
                obj_score.rotation_euler = (math.radians(270), 0, 0)
                obj_score.scale = (20, 20, 20)
                # get_strokes() reads the active object, i.e. obj_score.
                scale, offset = bcutils.cal_scale_and_offset(staff_lines, get_strokes()[:5])
            else:
                col = self._import_svg_as_curves(dest_file, filename)
                if col is None or len(col.objects) < 1:
                    self.report({'ERROR'}, '生成失败')
                    return {'FINISHED'}
                for obj in col.objects:
                    obj.scale = (40, 40, 40)
                context.scene.bc_obj_svg_score = col.objects[0]
                scale, offset = bcutils.cal_scale_and_offset(
                    staff_lines, get_score_staff_lines(col.name))

        del data['svg']
        data['scale'] = scale
        data['offset'] = offset
        context.scene.bc_score_data = json.dumps(data)
        # ``scene`` here is the add-on's ``scene`` module imported at the top of
        # the file, not ``context.scene``.
        scene.mc_expire_tips = meta.get('expired_at', '')

        if col is not None:
            bpy.ops.object.select_all(action='DESELECT')
            select_all_objects(col)
            bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='MEDIAN')
            group_objects(context, data, col)

        return {'FINISHED'}

    def _import_svg_as_grease_pencil(self, dest_file):
        """Import ``dest_file`` with whichever grease-pencil SVG importer exists.

        Returns the newly imported, active grease-pencil object. Returns ``None``
        after reporting an error when this build has no importer, or when no new
        grease-pencil object became active. Previously a failed import silently
        fell through and the previously active object was taken as the score.
        """
        wm_ops = dir(bpy.ops.wm)
        op_name = next(
            (name for name in ('gpencil_import_svg', 'grease_pencil_import_svg') if name in wm_ops),
            None,
        )
        if op_name is None:
            self.report({'ERROR'}, '该 Blender 版本无法生成蜡笔乐谱，请使用合适版本')
            return None
        existing_names = {obj.name for obj in bpy.data.objects}
        try:
            getattr(bpy.ops.wm, op_name)(filepath="", directory=os.path.dirname(dest_file),
                                         files=[{"name": os.path.basename(dest_file)}])
        except RuntimeError as err:
            # Kept best-effort (the object may still have been created), but no
            # longer silent; the outcome is validated below.
            print('grease pencil svg import raised: {}'.format(err))
        obj = bpy.context.active_object
        if (obj is None or obj.name in existing_names
                or obj.type not in {'GPENCIL', 'GREASEPENCIL'}):
            self.report({'ERROR'}, '生成失败')
            return None
        return obj

    @staticmethod
    def _import_svg_as_curves(dest_file, filename):
        """Import ``dest_file`` as curves and return the collection it created.

        The importer's collection is named after the file. If that name is
        already taken (e.g. the score was generated before), Blender gives the
        new collection a suffixed name. A plain ``collections[filename]`` lookup
        would then return the old, already-grouped collection. The new
        collection is therefore found by comparing collection names before and
        after the import, falling back to ``collections.get(filename)``.
        """
        existing_names = {c.name for c in bpy.data.collections}
        bpy.ops.import_curve.svg(filepath=dest_file)
        for new_col in bpy.data.collections:
            if new_col.name not in existing_names and new_col.name.startswith(filename):
                return new_col
        return bpy.data.collections.get(filename)


def generate_bounce_curve(context, points, factor):
    """Create a bounce curve through ``points`` and return it.

    Thin wrapper around ``bc.create``: all three arguments are forwarded
    unchanged, and its return value (used by callers as a curve object) is
    passed back.
    """
    created_curve = bc.create(context, points, factor)
    return created_curve


class RectifyOffsetOperator(bpy.types.Operator):
    # Re-key each generated bounce object against its (possibly hand-edited)
    # curve. The object's current X location is kept as the first landing, and
    # ``bc_animation_delay`` frames are added to every keyframe.
    bl_idname = "object.bc_rectify_offset"
    bl_label = "偏移校正"
    bl_description = "调整好曲线后再来校正，校正后再添加其他关键帧"
    bl_options = {'REGISTER', 'UNDO'}

    def execute(self, context):
        """Rebuild the X-location keyframes of every entry in ``bc_real_obj_bounce``.

        For each bounce object:
        - its 'curve' modifier's curve is measured with
          ``calculate_curve_section_length``;
        - one linear keyframe is set per note on its path, at
          ``int(onset_time * fps) + bc_animation_delay``.

        Stops with a report when a curve is missing or no longer has exactly one
        landing per note.
        """
        try:
            j = json.loads(context.scene.bc_score_data)
        except (TypeError, ValueError):
            self.report({'ERROR'}, '乐谱数据丢失，请重新生成乐谱')
            return {'FINISHED'}
        if not context.scene.bc_obj_bounce:
            # Informational only: the objects re-keyed below come from bc_real_obj_bounce.
            self.report({'INFO'}, '请选择一个弹跳物体')

        obj_score = context.scene.bc_obj_svg_score
        if not obj_score:
            self.report({'INFO'}, '请选择乐谱')
            return {'FINISHED'}
        obj_scales = obj_score.scale
        if obj_scales[0] != obj_scales[2]:
            self.report({'INFO'}, '请保证乐谱X/Y轴以等比例缩放')
            return {'FINISHED'}

        fps = bpy.context.scene.render.fps
        delay = context.scene.bc_animation_delay
        # GenerateBounceCurveOperator adds one entry per note path whose length
        # equals its part's ``times`` (other paths are skipped), in part/path
        # order. Walk the paths with the same rule so each entry lines up with
        # the path it was built from.
        bounce_items = iter(context.scene.bc_real_obj_bounce)
        for part in j.get('parts', []):
            times = [t.get('onset_time') for t in part.get('times', [])]
            if len(times) < 1:
                self.report({'INFO'}, '乐谱数据异常1')
                continue
            for path in part.get('note_paths', []):
                if len(path) != len(times):
                    continue
                item = next(bounce_items, None)
                if item is None:
                    return {'FINISHED'}
                obj = item.obj
                if obj is None:
                    # The bounce object was deleted; nothing to re-key.
                    continue
                location = obj.location.copy()
                modifier = obj.modifiers.get('curve')
                curve = modifier.object if modifier is not None else None
                if curve is None:
                    self.report({'INFO'}, '弹跳物体{}未设置曲线'.format(obj.name))
                    return {'FINISHED'}

                # Onset times of the notes present on this path (one per landing).
                ani_times = [t for t, point in zip(times, path) if point is not None]
                floor_dists = calculate_curve_section_length(curve)
                if not floor_dists or len(floor_dists) != len(ani_times):
                    self.report({'INFO'}, '弹跳物体{}对应弹跳曲线变更不符合规则'.format(obj.name))
                    return {'FINISHED'}
                # Shift so the first landing stays at the object's current X.
                delta = location[0] - floor_dists[0]
                obj.animation_data_clear()
                for ani_time, dist in zip(ani_times, floor_dists):
                    frame = int(ani_time * fps) + delay
                    utils.set_keyframe(obj, 'location', frame, (dist + delta, 0, 0))
                for fcurve in obj.animation_data.action.fcurves:
                    if fcurve.data_path == 'location':
                        for kf in fcurve.keyframe_points:
                            kf.interpolation = 'LINEAR'
        return {'FINISHED'}


class GenerateBounceCurveOperator(bpy.types.Operator):
    # Build one bounce curve per note path of every part. When a bounce object
    # is chosen, also add a copy of it that follows each curve (Curve modifier
    # along +X) with linear X-location keyframes at the note onsets.
    # (No class docstring on purpose: Blender would show it as the tooltip.)
    bl_idname = "object.bc_gen_bounce_curve"
    bl_label = "生成弹跳曲线/动画"
    bl_options = {'REGISTER', 'UNDO'}

    def execute(self, context):
        """Generate the bounce curves (and animated copies) from ``bc_score_data``.

        Paths whose length differs from their part's ``times`` are reported and
        skipped. Each created copy is registered in ``scene.bc_real_obj_bounce``
        (cleared first).
        """
        try:
            j = json.loads(context.scene.bc_score_data)
        except (TypeError, ValueError):
            self.report({'ERROR'}, '乐谱数据丢失，请重新生成乐谱')
            return {'FINISHED'}
        obj_score = context.scene.bc_obj_svg_score
        if not obj_score:
            self.report({'INFO'}, '请选择乐谱')
            return {'FINISHED'}
        obj_scales = obj_score.scale
        if obj_scales[0] != obj_scales[2]:
            self.report({'INFO'}, '请保证乐谱X/Y轴以等比例缩放')
            return {'FINISHED'}
        obj_bullet = context.scene.bc_obj_bounce
        if not obj_bullet:
            self.report({'INFO'}, '选择弹跳对象才会自动生成动画')
        scale = j.get('scale')
        offset = j.get('offset')
        if scale is None or offset is None:
            self.report({'ERROR'}, '乐谱数据丢失，请重新生成乐谱')
            return {'FINISHED'}

        # Score space -> scene space: (point * scale + offset) * score X/Z scale.
        scale_vec = mathutils.Vector(scale)
        offset_vec = mathutils.Vector(offset)
        score_scale = mathutils.Vector((obj_scales[0], obj_scales[2]))
        fps = bpy.context.scene.render.fps

        context.scene.bc_real_obj_bounce.clear()
        col_bullet = utils.get_or_new_collection('子弹')
        col_curve = utils.get_or_new_collection('跳线')
        for part in j.get('parts', []):
            times = [t.get('onset_time') for t in part.get('times', [])]
            note_paths = part.get('note_paths', [])
            if len(times) < 1:
                self.report({'INFO'}, '乐谱数据异常1')
                continue
            for path in note_paths:
                if len(times) != len(path):
                    print('times len: {}, chords len:{}'.format(len(times), len(path)))
                    self.report({'INFO'}, '乐谱数据异常2')
                    continue
                # Indices into ``times`` of the notes present on this path.
                point_indices = [i for i, point in enumerate(path) if point is not None]
                points = []
                for i in point_indices:
                    p = (mathutils.Vector(path[i]) * scale_vec + offset_vec) * score_scale
                    points.append(mathutils.Vector((p.x, p.y, 0)))
                points.reverse()

                # NOTE(review): the operator (``self``), not ``context``, is passed
                # as the first argument; kept as-is because bc.create is not visible here.
                curve = generate_bounce_curve(self, points, context.scene.bc_bounce_factor)
                utils.move_object(col_curve, curve)
                if not obj_bullet:
                    continue

                # NOTE(review): if utils.copy_obj shares the source's action, the
                # keyframes below land in that shared action -- confirm.
                obj_bullet_copy = utils.copy_obj(obj_bullet)
                utils.move_object(col_bullet, obj_bullet_copy)
                item = context.scene.bc_real_obj_bounce.add()
                item.obj = obj_bullet_copy
                modifier = obj_bullet_copy.modifiers.new('curve', 'CURVE')
                modifier.object = curve
                modifier.deform_axis = 'POS_X'

                obj_bullet_copy.location = (0, 0, 0)
                obj_bullet_copy.lock_location[1] = True
                obj_bullet_copy.lock_location[2] = True

                floor_dists = calculate_curve_section_length(curve)
                if len(floor_dists) != len(point_indices):
                    # Previously an IndexError when the curve had extra landings;
                    # now only the matching prefix is keyed.
                    print('floor points len: {}, notes len: {}'.format(
                        len(floor_dists), len(point_indices)))
                for note_ind, dist in zip(point_indices, floor_dists):
                    frame = int(times[note_ind] * fps)
                    utils.set_keyframe(obj_bullet_copy, 'location', frame, (dist, 0, 0))
                anim = obj_bullet_copy.animation_data
                if anim is None or anim.action is None:
                    continue
                for fcurve in anim.action.fcurves:
                    if fcurve.data_path == 'location':
                        for kf in fcurve.keyframe_points:
                            kf.interpolation = 'LINEAR'
        return {'FINISHED'}


def calculate_curve_section_length(curve):
    """Return the accumulated arc length of ``curve`` at each vertex with z == 0.

    How it works:
    - a temporary copy of ``curve`` (first spline at resolution 64) is converted
      to a mesh;
    - its vertices are walked in index order, accumulating the distance between
      consecutive vertices;
    - the running total is recorded at every vertex whose z is exactly 0.

    Returns an empty list if the conversion produced no mesh or no vertices.

    Side effects:
    - every object in the current view layer is deselected;
    - the temporary copy becomes the active object and is then deleted together
      with its mesh datablock.
    """
    tmp_curve = utils.copy_obj(curve)
    tmp_curve.data.splines[0].resolution_u = 64
    view_layer = bpy.context.view_layer
    # Object.select_set() raises RuntimeError for objects that are not in the
    # current view layer, and bpy.data.objects also contains such objects
    # (other scenes, excluded collections). Only touch this view layer's objects.
    for obj in view_layer.objects:
        obj.select_set(False)
    view_layer.objects.active = tmp_curve
    tmp_curve.select_set(True)
    bpy.ops.object.convert(target='MESH')

    section_lens = []
    tmp_mesh = tmp_curve.data if tmp_curve.type == 'MESH' else None
    if tmp_mesh is not None and len(tmp_mesh.vertices) > 0:
        acc = 0
        last_co = tmp_mesh.vertices[0].co
        for vertex in tmp_mesh.vertices:
            acc += (vertex.co - last_co).length
            last_co = vertex.co
            # Exact comparison kept on purpose; presumably the landing points are
            # built with z = 0 exactly (TODO confirm against bc.create).
            if vertex.co.z == 0:
                section_lens.append(acc)
    bpy.data.objects.remove(tmp_curve, do_unlink=True)
    # Removing the object does not remove its mesh; drop the now-orphaned data.
    if tmp_mesh is not None and tmp_mesh.users == 0:
        bpy.data.meshes.remove(tmp_mesh)
    return section_lens


class GenerateCameraOperator(bpy.types.Operator):
    # Animate the chosen camera along X so it follows the notes of the first part.
    # (No class docstring on purpose: Blender would show it as the tooltip.)
    bl_idname = "object.bc_gen_camera_animation"
    bl_label = "生成摄像机动画"
    bl_options = {'REGISTER', 'UNDO'}

    def execute(self, context):
        """Key the camera's X location to follow part 0 of ``bc_score_data``.

        Note selection: for every note index, the first note path of part 0 that
        has a note there supplies its position.

        Keyframes (all linear):
        - frame 0 at the current camera location;
        - every 10th note, plus the last note, at ``int(onset_time * fps)``.

        Each note keyframe shifts X by the note's distance from the first note;
        Y and Z stay at the current location. Always returns a valid operator
        result set (the old data-error paths returned None, which Blender rejects).
        """
        camera = context.scene.bc_camera
        if camera is None:
            self.report({'ERROR'}, '请先选中摄像机')
            return {'FINISHED'}
        try:
            j = json.loads(context.scene.bc_score_data)
        except (TypeError, ValueError):
            self.report({'ERROR'}, '乐谱数据丢失，请重新生成乐谱')
            return {'FINISHED'}
        scale = j.get('scale')
        offset = j.get('offset')
        parts = j.get('parts', [])
        obj_score = context.scene.bc_obj_svg_score
        if not obj_score:
            self.report({'INFO'}, '请选择乐谱')
            return {'FINISHED'}
        obj_scales = obj_score.scale
        if obj_scales[0] != obj_scales[2]:
            self.report({'INFO'}, '请保证乐谱X/Y轴以等比例缩放')
            return {'FINISHED'}
        if scale is None or offset is None or not parts:
            self.report({'ERROR'}, '乐谱数据丢失，请重新生成乐谱')
            return {'FINISHED'}

        part = parts[0]
        times = [t.get('onset_time') for t in part.get('times', [])]
        if len(times) < 1:
            self.report({'INFO'}, '乐谱数据异常1')
            return {'FINISHED'}
        note_paths = part.get('note_paths', [])
        notes = []
        for ind_note in range(len(note_paths[0]) if note_paths else 0):
            # First path (in order) that has a note at this index; shorter paths are skipped.
            note = next(
                (path[ind_note] for path in note_paths
                 if ind_note < len(path) and path[ind_note] is not None),
                None,
            )
            if note is not None:
                notes.append(note)
        if len(times) != len(notes):
            print('times len: {}, chords len:{}'.format(len(times), len(notes)))
            self.report({'INFO'}, '乐谱数据异常2')
            return {'FINISHED'}

        scale_vec = mathutils.Vector(scale)
        offset_vec = mathutils.Vector(offset)
        score_scale = mathutils.Vector((obj_scales[0], obj_scales[2]))
        points = [(mathutils.Vector(note) * scale_vec + offset_vec) * score_scale
                  for note in notes]

        camera.animation_data_clear()
        location = camera.location.copy()
        utils.set_keyframe(camera, 'location', 0, location)
        fps = bpy.context.scene.render.fps
        key_indices = list(range(0, len(times), 10))
        # Also key the final note so the camera keeps following to the end.
        if key_indices[-1] != len(times) - 1:
            key_indices.append(len(times) - 1)
        for ind in key_indices:
            loc = (points[ind][0] + location[0] - points[0][0], location[1], location[2])
            utils.set_keyframe(camera, 'location', int(times[ind] * fps), loc)
        for fcurve in camera.animation_data.action.fcurves:
            if fcurve.data_path == 'location':
                for kf in fcurve.keyframe_points:
                    kf.interpolation = 'LINEAR'
        return {'FINISHED'}


class BounceCurvePanel(bpy.types.Panel):
    # Properties editor panel (World tab) for the bounce-curve workflow. It has
    # sections for global auth settings, score generation, bounce animation,
    # offset rectification and camera animation.
    # (No class docstring on purpose: Blender would use it as the description.)
    bl_label = "跳动曲线"
    bl_idname = "TOOL_PT_mercury_client_panel_bounce_curve"
    bl_space_type = 'PROPERTIES'
    bl_region_type = 'WINDOW'
    bl_context = 'world'

    @classmethod
    def poll(cls, context):
        # The panel is always shown.
        return True

    def draw(self, context):
        """Draw the global settings, then the score, animation and camera controls."""
        ui = self.layout
        sc = context.scene
        ui.use_property_split = True

        # Global configuration: each field on its own row.
        ui.label(text="全局配置")
        ui.row().prop(sc, "mc_expire_tips", placeholder='无需填写')
        for auth_field in ("mc_auth_key", "mc_auth_secret"):
            ui.row().prop(sc, auth_field)
        ui.separator()

        # Score generation.
        for field in ("bc_midi", "bc_import_svg_as_gspencil"):
            ui.prop(sc, field)
        ui.operator(GenerateScoreOperator.bl_idname, text="生成乐谱")
        ui.separator()

        # Bounce curves / animation.
        ui.label(text="动画配置")
        for field in ("bc_obj_bounce", "bc_bounce_factor"):
            ui.prop(sc, field)
        ui.operator(GenerateBounceCurveOperator.bl_idname, text="生成弹跳曲线/动画")
        ui.separator()

        # Offset rectification.
        ui.prop(sc, "bc_animation_delay")
        ui.operator(RectifyOffsetOperator.bl_idname, text='偏移校正')
        ui.separator()

        # Camera follow animation.
        ui.prop(sc, "bc_camera")
        ui.operator(GenerateCameraOperator.bl_idname, text="生成摄像机动画")
