
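// Build (32-bit x86 + glibc only: uthread_register() stores raw i386
// addresses as int values in jmp_buf slots), e.g.:
// gcc -m32 -D_DEBUGP test_usched.c -lm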

#include <math.h>
#include <setjmp.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* ---- configuration and scheduler-wide counters ---- */

/* Turns on the diagnostic printf() calls in the demo threads. */
#define _DEBUGP 1

/* Number of user-thread slots (the scheduler uses one extra slot, index 0). */
#define MAX_UTHREADS 512

/* Total dispatches of user threads performed by uthread_runall(). */
unsigned numSwitches = 0u;

/* User threads registered so far; they occupy slots 1..num_uthreads. */
unsigned num_uthreads = 0u;

/* Entry-point signature of a user thread: one opaque argument, no result. */
typedef void usched_func_ptr_t(void *);

/* Start-up arguments of a user thread.  uthread_register() copies one of
   these to the heap and the copy is what the thread's entry function sees. */
typedef struct uthread_input_args_T {
    int  uthreadId;    /* thread identifier chosen by the caller */
    int  data1;        /* integer payload */
    char data2[100];   /* text payload (printed with %s by the demo threads) */
} uthread_input_args_t;

/* One execution context known to the scheduler. */
typedef struct usched_uthread_T {
    jmp_buf            env;           /* saved execution context (setjmp/longjmp) */
    usched_func_ptr_t *func;          /* entry point given to uthread_register() */
    void              *context;       /* heap copy of the start-up arguments */
    int                status;        /* 0 = never run, 1 = started, 2 = terminated */
    char              *uthread_name;  /* heap copy of the name used in messages */
    void              *stk;           /* heap-allocated stack; 0 once released */
} usched_uthread_t;

/* Context table: entry 0 is the scheduler itself, entries 1..MAX_UTHREADS
   are user threads in registration order. */
usched_uthread_t usched_uthread_store[1 + MAX_UTHREADS];

/* Context of the user thread that is currently dispatched. */
usched_uthread_t *usched_cur_uthread;

/* Context of the scheduler loop (uthread_runall() points it at entry 0). */
usched_uthread_t *usched_sched_uthread;

/* Scheduler API; the definitions appear later in this file.
   The (void) parameter lists make these real prototypes, so calls with a
   wrong number of arguments are diagnosed (an empty () list is not checked). */
int uthread_register(usched_func_ptr_t func,
                     uthread_input_args_t *args,
                     char * uthread_name);
void uthread_runall(void);
int uthread_exit(void);

/* uthread_switch(old, new): save the current execution context in jmp_buf
   `old`, then resume the context saved in jmp_buf `new`.  When some other
   context later longjmp()s to `old`, execution continues after the macro.
   The do { } while (0) wrapper makes the macro a single statement, so it is
   safe as the body of an unbraced if/else (a bare `if` would swallow a
   following `else`). */
#define uthread_switch(old, new) \
    do { if (!setjmp(old)) longjmp(new, 1); } while (0)

/* uthread_yield(): suspend the current user thread and resume the scheduler;
   the thread continues after the macro on its next dispatch. */
#define uthread_yield() \
    uthread_switch(usched_cur_uthread->env, usched_sched_uthread->env)


/* Size in bytes of each user thread's private stack; override at build time
   with -DUSCHED_STACK_SIZE=<bytes>.
   NOTE(review): 4096 bytes may be tight for threads that call printf() with
   floating-point conversions (as the demo threads do) -- confirm. */
#ifndef USCHED_STACK_SIZE
#define USCHED_STACK_SIZE 4096
#endif

/*
 * Register a new user thread and prepare its first context switch.
 *
 * func         entry point; receives a pointer to a private heap copy of
 *              *args as its only argument.
 * args         start-up arguments, copied (the caller may reuse *args);
 *              NULL gives the thread an all-zero argument block.
 * uthread_name name used in diagnostics; duplicated with strdup().
 *
 * Returns 0 on success.  Returns 1 if func or uthread_name is NULL, if all
 * MAX_UTHREADS slots are in use, or if an allocation fails.  On failure
 * nothing is registered and num_uthreads is left unchanged.
 *
 * The initial context is built by overwriting the saved PC/BP/SP slots of a
 * jmp_buf with int-sized addresses, i.e. the i386 glibc jmp_buf layout
 * (see "i386/gcc-specific stack layout" below).  JB_PC/JB_BP/JB_SP are not
 * defined in this file; presumably they come from the system <setjmp.h> on
 * the original target -- confirm.  NOTE(review): glibc builds that mangle
 * the pointers stored in jmp_buf will not accept raw addresses here.
 */
int uthread_register(usched_func_ptr_t func, 
                     uthread_input_args_t *args, 
                     char * uthread_name) {

    usched_uthread_t *u;
    void *ctx;
    char *name;
    void *stk;

    if (func == NULL || uthread_name == NULL) return 1;

    /* Slots 1..MAX_UTHREADS belong to user threads (slot 0 is the
       scheduler).  Check before touching num_uthreads so that a rejected
       registration neither writes past the store nor inflates the count. */
    if (num_uthreads >= MAX_UTHREADS) return 1;

    /* Acquire every resource first so a failure leaves no half-built thread. */
    ctx  = calloc(1, sizeof(uthread_input_args_t));
    name = strdup(uthread_name);
    stk  = calloc(1, USCHED_STACK_SIZE);
    if (ctx == NULL || name == NULL || stk == NULL) {
        free(ctx);
        free(name);
        free(stk);
        return 1;
    }

    /* copy input args into heap */
    if (args != NULL)
        memcpy(ctx, args, sizeof(uthread_input_args_t));

    u = &(usched_uthread_store[num_uthreads + 1]);
    u->func = func;
    u->context = ctx;
    u->status = 0;
    u->uthread_name = name;
    u->stk = stk;

    /* initialise context with reasonable values, using setjmp */
    setjmp(u->env);

    /* Store address for the first context switch to the user thread:
       this is - of course - the function entry point */
    u->env[0].__jmpbuf[JB_PC] = (int)func;

    /* The stack base pointer gets the highest available 
       address of the stack section. Stack grows towards
       smaller addresses.  (char *) keeps the arithmetic in standard C. */
    u->env[0].__jmpbuf[JB_BP] =
        (int)((char *)u->stk + USCHED_STACK_SIZE - sizeof(int32_t));

    /* Place the input parameter (pointer value u->context) on the stack
       (this is i386/gcc-specific stack layout) */
    *((void**)u->env[0].__jmpbuf[JB_BP]) = u->context;

    /* Now the top of stack is behind the u->context parameter 
       written to the stack */
    u->env[0].__jmpbuf[JB_SP] = u->env[0].__jmpbuf[JB_BP] - sizeof(int32_t);

    /* Put the address of the exit function on the top of the stack, where
       func will find its return address: returning from the user thread
       therefore enters uthread_exit(), which switches safely back to the
       scheduler context. */
    *((int *)u->env[0].__jmpbuf[JB_SP]) = (int)uthread_exit; /* old EIP */

    /* Publish the slot only once it is fully initialised. */
    num_uthreads++;
    return 0;
}

/*
 * Terminate the current user thread and return control to the scheduler.
 *
 * uthread_register() stores this function's address where a thread's entry
 * function finds its return address, so returning from the entry function
 * lands here.  Prints a message, marks the thread terminated (status 2) so
 * uthread_runall() can release its stack, and longjmp()s into the scheduler
 * context.  It never returns; the int return type is kept only for
 * interface compatibility.
 */
int uthread_exit(void) {
    printf("Terminating user thread %s\n",usched_cur_uthread->uthread_name);

    /* user thread state is now "terminated" */
    usched_cur_uthread->status = 2;

    /* switch back to the scheduler; control never comes back here */
    longjmp(usched_sched_uthread->env,1);
}

/* Upper bound on numSwitches (total dispatches) before uthread_runall() stops
   scheduling; threads still alive at that point are left suspended.
   Override at build time with -DUSCHED_MAX_SWITCHES=<n>. */
#ifndef USCHED_MAX_SWITCHES
#define USCHED_MAX_SWITCHES 50
#endif

/*
 * Round-robin scheduler loop.
 *
 * Makes slot 0 of usched_uthread_store the scheduler context, then
 * repeatedly walks slots 1..num_uthreads.  Each thread that has not
 * terminated (status 0 or 1) is marked status 1 and resumed until it yields
 * or exits.  A terminated thread (status 2) has its stack, argument copy and
 * name released once.  Stops when a full pass finds no runnable thread or
 * when numSwitches reaches USCHED_MAX_SWITCHES.
 */
void uthread_runall(void) {
    int active;
    unsigned int cur_uthread_idx;   /* unsigned to match num_uthreads */
    usched_sched_uthread = &usched_uthread_store[0];

    do {
        active = 0;
        for (cur_uthread_idx = 1; cur_uthread_idx <= num_uthreads; cur_uthread_idx++) {
            usched_cur_uthread = &usched_uthread_store[cur_uthread_idx];
            switch (usched_cur_uthread->status) {
            case 0: case 1: /* never been run or running */
                numSwitches++;
                active = 1;
                usched_cur_uthread->status = 1;
                /* Returns here when the thread yields or exits. */
                uthread_switch(usched_sched_uthread->env, usched_cur_uthread->env);
                break;
            case 2:
                /* This uthread has terminated: release its resources once
                   (stk doubles as the "not yet released" marker). */
                if (usched_cur_uthread->stk) {
                    printf("Termination - free context of uthread %s\n",
                           usched_cur_uthread->uthread_name);
                    free(usched_cur_uthread->stk);
                    usched_cur_uthread->env[0].__jmpbuf[JB_SP] = 0;
                    usched_cur_uthread->stk = NULL;
                    /* The heap copies made by uthread_register() are no
                       longer needed either. */
                    free(usched_cur_uthread->context);
                    usched_cur_uthread->context = NULL;
                    free(usched_cur_uthread->uthread_name);
                    usched_cur_uthread->uthread_name = NULL;
                }
                break;
            default:
                /* Unknown status: leave the slot alone. */
                break;
            }
        }
    } while (numSwitches < USCHED_MAX_SWITCHES && active);
}


/* ---- demo / test program ---- */

/* Test flags; neither is read or written elsewhere in this file. */
int test_status = 0, f2_trigger = 0;

/*
 * Demo helper run inside a user thread.  Prints x, increments its local
 * copy, yields to the scheduler from inside this nested call, and prints the
 * incremented value once the thread is resumed.  This shows that a thread
 * can be suspended below its entry function.  t only labels the output.
 */
void aux_func(int x, int t) {
#ifdef _DEBUGP
    printf("aux_func(): yield thread %d: x = %d\n", t, x);
    fflush(NULL);
#endif

    x += 1;
    uthread_yield();

#ifdef _DEBUGP
    printf("aux_func(): resume thread %d: x = %d\n", t, x);
    fflush(NULL);
#endif
}

/* Returns twice k.  (Not referenced elsewhere in this file.) */
int f(int k) {
    return k + k;
}


/*
 * Demo thread 1.  Prints its start-up arguments and yields once.  It then
 * loops forever: each round updates z = sin(z + i), bumps i, calls
 * aux_func() (which yields internally) and yields again.  It never returns,
 * so it only stops being dispatched when uthread_runall() reaches its
 * dispatch limit.
 */
void func_1(void *a) {
    const uthread_input_args_t *in = (const uthread_input_args_t *)a;
    double acc = 3.14;
    int round = 0;

#ifdef _DEBUGP
    printf("%s ### INITIAL ### id = %d  data1 = %d data2 = %s i = %d\n",
           usched_cur_uthread->uthread_name,
           in->uthreadId, in->data1, in->data2, round);
#endif

    uthread_yield();

#ifdef _DEBUGP
    printf("%s ### BEFORE WHILE ### processing: i = %d\n",
           usched_cur_uthread->uthread_name, round);
#endif

    for (;;) {
        acc = sin(acc + round);

#ifdef _DEBUGP
        printf("%s processing: i = %d z = %f\n",
               usched_cur_uthread->uthread_name, round, acc);
#endif

        ++round;
        aux_func(round, in->uthreadId);
        uthread_yield();
    }
}

/*
 * Demo thread 2.  Prints its start-up arguments and yields once.  It then
 * runs ten rounds (i = 100..109, with j counting from 0).  Each round
 * updates z = exp(z / i), calls aux_func() (which yields internally) and
 * yields again.  Afterwards it returns, which ends the thread through
 * uthread_exit().
 */
void func_2(void *a) {
    const uthread_input_args_t *in = (const uthread_input_args_t *)a;
    double acc = 2.9;
    int i;
    int j;

#ifdef _DEBUGP
    printf("%s ### INITIAL ### id is %d data1 = %d data2 = %s z = %f\n",
           usched_cur_uthread->uthread_name,
           in->uthreadId, in->data1, in->data2, acc);
#endif

    uthread_yield();

    for (i = 100, j = 0; i < 110; i++, j++) {
        acc = exp(acc / i);

#ifdef _DEBUGP
        printf("%s processing i= %d j = %d z = %f\n",
               usched_cur_uthread->uthread_name, i, j, acc);
#endif

        aux_func(i, in->uthreadId);
        uthread_yield();
    }
}

/*
 * Demo driver.  Registers func_1 and func_2 with different start-up
 * arguments, then runs the scheduler until both threads are done or the
 * dispatch limit is hit.  Returns 0 on success, or 1 (after a message on
 * stderr) if a thread could not be registered.
 */
int main(void) {
    /* Zero-initialised so the whole struct, including the unused tail of
       data2, has defined contents when uthread_register() copies it. */
    uthread_input_args_t args = {0};

    args.uthreadId = 1;
    args.data1 = 1001;
    snprintf(args.data2, sizeof args.data2, "%s", "first thread ...");
    if (uthread_register(&func_1, &args, "func_1") != 0) {
        fprintf(stderr, "cannot register user thread %s\n", "func_1");
        return 1;
    }

    /* args can be reused: uthread_register() keeps its own copy. */
    args.uthreadId = 2;
    args.data1 = 1002;
    snprintf(args.data2, sizeof args.data2, "%s", "second thread ...");
    if (uthread_register(&func_2, &args, "func_2") != 0) {
        fprintf(stderr, "cannot register user thread %s\n", "func_2");
        return 1;
    }

    uthread_runall();
    return 0;
}
