"""
Copyright 2022 ColdQuanta Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize

from braket.aws import AwsDevice
from braket.pulse import PulseSequence
from braket.circuits import Circuit

# Module-level handle to the QPU identified by the ARN below; idle_gate() looks up its
# per-qubit control frames via 'device.frames'.
# NOTE(review): this is constructed at import time -- presumably it requires AWS credentials and
# network access to succeed; confirm before importing this module in offline contexts.
device = AwsDevice("arn:aws:braket:us-west-1::device/qpu/rigetti/Aspen-M-2")


def idle_gate(qubit, idle_time):
    """
    Build a circuit whose only content is a wait on the given qubit.

    The wait is expressed as a pulse gate wrapping a pulse sequence that contains a single delay of 'idle_time' on the
    frame named "q{qubit}_rf_frame" of the module-level device, which keeps consecutive instructions on that qubit apart.
    NOTE(review): 'idle_time' is presumably in seconds -- confirm against the pulse API.
    """
    # Look up the frame first, so an unknown qubit fails before anything else is built.
    rf_frame = device.frames[f"q{qubit}_rf_frame"]

    delay_only = PulseSequence()
    delay_only = delay_only.delay(rf_frame, idle_time)

    idle_circuit = Circuit()
    return idle_circuit.pulse_gate(qubit, delay_only)


def x_pulse(qubit, angle=np.pi):
    """
    Return a circuit that applies a single rx rotation of 'angle' (default: pi) to the given qubit.
    """
    rotation = Circuit()
    return rotation.rx(qubit, angle)


def y_pulse(qubit, angle=np.pi):
    """
    Return a circuit that rotates the given qubit by 'angle' (default: pi) around the Y axis.

    A native Y rotation is not available on every device, so the rotation is assembled from an X rotation placed between
    two opposite quarter-turn rotations around the Z axis.
    """
    quarter_turn = 0.5 * np.pi

    rotation = Circuit()
    rotation = rotation.rz(qubit, -quarter_turn)  # move into the frame in which X acts as Y
    rotation = rotation.rx(qubit, angle)          # the actual rotation
    return rotation.rz(qubit, +quarter_turn)      # undo the frame change


def idle_with_XY4(qubit, idle_time, num_cycles):
    """
    Build a circuit of 'num_cycles' XY4 cycles on the given qubit, with 'idle_time' of total delay spread evenly.

    Each cycle is the pulse pattern X, Y, X, Y. Every pulse is padded on both sides by half of the pulse-to-pulse
    spacing, so back-to-back padded pulses are separated by one full spacing (namely, 2 tau).
    """
    pulses_per_cycle = 4  # one XY4 cycle consists of four pi pulses
    spacing = idle_time / (pulses_per_cycle * num_cycles)  # delay between consecutive pulses

    # Half a spacing of idling goes on each side of every pulse.
    half_gap = idle_gate(qubit, spacing / 2)
    pi_x = half_gap + x_pulse(qubit) + half_gap
    pi_y = half_gap + y_pulse(qubit) + half_gap

    # Repeat the single-cycle pattern and hand the pieces to one circuit.
    sequence = [pi_x, pi_y, pi_x, pi_y] * num_cycles
    return Circuit(sequence)


################################################################################
# fitting and plotting methods


def exp_decay(time, initial_value, decay_rate):
    """
    Evaluate the exponential decay model initial_value * exp(-time * decay_rate).

    Used as the model function for curve fitting; 'time' may be a scalar or a numpy array, in which case the model is
    evaluated element-wise.
    """
    exponent = -time * decay_rate
    decay_factor = np.exp(exponent)
    return initial_value * decay_factor


def plot_and_fit_fidelities(times, fidelities_idle, fidelities_XY4, ylabel=None):
    """
    Plot fidelities acquired from an experiment comparing qubit idling to an XY4 sequence.
    Add curves that fit fidelities to an exponential decay.

    Args:
        times: the idle times at which fidelities were measured; any array-like (list, tuple, numpy array) is accepted.
            Values are multiplied by 1e6 for an x axis labeled in microseconds, so they are expected in seconds.
        fidelities_idle: fidelities measured when the qubit simply idles, one per entry of 'times'.
        fidelities_XY4: fidelities measured when the idle period is filled with an XY4 sequence, one per entry of 'times'.
        ylabel: optional label for the y axis; "Fidelities" is used when it is None or empty.

    The figure is displayed with plt.show(); nothing is returned.
    """
    # Coerce to an array so that plain sequences survive the 'times * 1e6' unit conversion below.
    times = np.asarray(times)

    fit_params_idle, _ = scipy.optimize.curve_fit(exp_decay, times, fidelities_idle)
    fit_params_XY4, _ = scipy.optimize.curve_fit(exp_decay, times, fidelities_XY4)

    # Evaluate the fitted models on a dense grid spanning the measured range (vectorized over the grid).
    fit_times = np.linspace(times[0], times[-1], 100)
    fit_fidelities_idle = exp_decay(fit_times, *fit_params_idle)
    fit_fidelities_XY4 = exp_decay(fit_times, *fit_params_XY4)

    plt.figure(figsize=(4, 3))
    plt.plot(times * 1e6, fidelities_idle, "ro", label="idle")
    plt.plot(fit_times * 1e6, fit_fidelities_idle, "r-")
    plt.plot(times * 1e6, fidelities_XY4, "bo", label="XY4")
    plt.plot(fit_times * 1e6, fit_fidelities_XY4, "b-")

    plt.ylim(0, 1)
    plt.xlabel("Idle time (µs)")
    plt.ylabel(ylabel if ylabel else "Fidelities")
    plt.legend(loc="best")
    plt.tight_layout()
    plt.show()
