00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include <stdio.h>
00021 #include <stdlib.h>
00022
00023 #include "sba.h"
00024
00025 static void sba_crsm_print(struct sba_crsm *sm, FILE *fp);
00026 static void sba_crsm_build(struct sba_crsm *sm, int *m, int nr, int nc);
00027
00028
00029 void sba_crsm_alloc(struct sba_crsm *sm, int nr, int nc, int nnz)
00030 {
00031 int msz;
00032
00033 sm->nr=nr;
00034 sm->nc=nc;
00035 sm->nnz=nnz;
00036 msz=2*nnz+nr+1;
00037 sm->val=(int *)malloc(msz*sizeof(int));
00038 if(!sm->val){
00039 fprintf(stderr, "SBA: memory allocation request failed in sba_crsm_alloc() [nr=%d, nc=%d, nnz=%d]\n", nr, nc, nnz);
00040 exit(1);
00041 }
00042 sm->colidx=sm->val+nnz;
00043 sm->rowptr=sm->colidx+nnz;
00044 }
00045
00046
00047 void sba_crsm_free(struct sba_crsm *sm)
00048 {
00049 sm->nr=sm->nc=sm->nnz=-1;
00050 free(sm->val);
00051 sm->val=sm->colidx=sm->rowptr=NULL;
00052 }
00053
00054 static void sba_crsm_print(struct sba_crsm *sm, FILE *fp)
00055 {
00056 register int i;
00057
00058 fprintf(fp, "matrix is %dx%d, %d non-zeros\nval: ", sm->nr, sm->nc, sm->nnz);
00059 for(i=0; i<sm->nnz; ++i)
00060 fprintf(fp, "%d ", sm->val[i]);
00061 fprintf(fp, "\ncolidx: ");
00062 for(i=0; i<sm->nnz; ++i)
00063 fprintf(fp, "%d ", sm->colidx[i]);
00064 fprintf(fp, "\nrowptr: ");
00065 for(i=0; i<=sm->nr; ++i)
00066 fprintf(fp, "%d ", sm->rowptr[i]);
00067 fprintf(fp, "\n");
00068 }
00069
00070
00071 static void sba_crsm_build(struct sba_crsm *sm, int *m, int nr, int nc)
00072 {
00073 int nnz;
00074 register int i, j, k;
00075
00076
00077 for(i=nnz=0; i<nr; ++i)
00078 for(j=0; j<nc; ++j)
00079 if(m[i*nc+j]!=0) ++nnz;
00080
00081 sba_crsm_alloc(sm, nr, nc, nnz);
00082
00083
00084 for(i=k=0; i<nr; ++i){
00085 sm->rowptr[i]=k;
00086 for(j=0; j<nc; ++j)
00087 if(m[i*nc+j]!=0){
00088 sm->val[k]=m[i*nc+j];
00089 sm->colidx[k++]=j;
00090 }
00091 }
00092 sm->rowptr[nr]=nnz;
00093 }
00094
00095
00096 int sba_crsm_elmidx(struct sba_crsm *sm, int i, int j)
00097 {
00098 register int low, high, mid, diff;
00099
00100 low=sm->rowptr[i];
00101 high=sm->rowptr[i+1]-1;
00102
00103
00104 while(low<=high){
00105
00106
00107
00108
00109 mid=(low+high)>>1;
00110 diff=j-sm->colidx[mid];
00111 if(diff<0)
00112 high=mid-1;
00113 else if(diff>0)
00114 low=mid+1;
00115 else
00116 return mid;
00117 }
00118
00119 return -1;
00120 }
00121
00122
00123
00124
00125
00126 int sba_crsm_elmidxp(struct sba_crsm *sm, int i, int j, int jp, int jpidx)
00127 {
00128 register int low, high, mid, diff;
00129
00130 diff=j-jp;
00131 if(diff>0){
00132 low=jpidx+1;
00133 high=sm->rowptr[i+1]-1;
00134 }
00135 else if(diff==0)
00136 return jpidx;
00137 else{
00138 low=sm->rowptr[i];
00139 high=jpidx-1;
00140 }
00141
00142
00143 while(low<=high){
00144
00145
00146
00147
00148 mid=(low+high)>>1;
00149 diff=j-sm->colidx[mid];
00150 if(diff<0)
00151 high=mid-1;
00152 else if(diff>0)
00153 low=mid+1;
00154 else
00155 return mid;
00156 }
00157
00158 return -1;
00159 }
00160
00161
00162
00163
00164
00165
00166 int sba_crsm_row_elmidxs(struct sba_crsm *sm, int i, int *vidxs, int *jidxs)
00167 {
00168 register int j, k;
00169
00170 for(j=sm->rowptr[i], k=0; j<sm->rowptr[i+1]; ++j, ++k){
00171 vidxs[k]=j;
00172 jidxs[k]=sm->colidx[j];
00173 }
00174
00175 return k;
00176 }
00177
00178
00179
00180
00181
00182
00183 int sba_crsm_col_elmidxs(struct sba_crsm *sm, int j, int *vidxs, int *iidxs)
00184 {
00185 register int *nextrowptr=sm->rowptr+1;
00186 register int i, l;
00187 register int low, high, mid, diff;
00188
00189 for(i=l=0; i<sm->nr; ++i){
00190 low=sm->rowptr[i];
00191 high=nextrowptr[i]-1;
00192
00193
00194 while(low<=high){
00195
00196
00197 mid=(low+high)>>1;
00198 diff=j-sm->colidx[mid];
00199 if(diff<0)
00200 high=mid-1;
00201 else if(diff>0)
00202 low=mid+1;
00203 else{
00204 vidxs[l]=mid;
00205 iidxs[l++]=i;
00206 break;
00207 }
00208 }
00209 }
00210
00211 return l;
00212 }
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
#if 0
/*
 * Return 1 if columns j and k of the CRS matrix have a stored element in at
 * least one common row (trivially true when j==k), 0 otherwise.
 * Assumes the column indices of every row are stored in increasing order.
 * Currently disabled (not compiled).
 */
int sba_crsm_common_row(struct sba_crsm *sm, int j, int k)
{
  register int i, low, high, mid, diff;

  if(j==k) return 1;

  for(i=0; i<sm->nr; ++i){
    low=sm->rowptr[i];
    high=sm->rowptr[i+1]-1;

    /* Empty row: colidx[low]/colidx[high] do not belong to this row. For
     * row 0 of an all-zero matrix, colidx[-1] lies outside the allocation. */
    if(low>high) continue;

    /* both columns must fall within this row's column span */
    if(j<sm->colidx[low] || j>sm->colidx[high] ||
       k<sm->colidx[low] || k>sm->colidx[high])
      continue;

    /* binary search for j in row i */
    while(low<=high){
      mid=(low+high)>>1;
      diff=j-sm->colidx[mid];
      if(diff<0)
        high=mid-1;
      else if(diff>0)
        low=mid+1;
      else
        goto jfound;
    }

    continue; /* j not present in row i */

jfound:
    /* k can only lie on the side of j's position given by its order wrt j */
    if(j>k){
      low=sm->rowptr[i];
      high=mid-1;
    }
    else{
      low=mid+1;
      high=sm->rowptr[i+1]-1;
    }

    /* guard against an empty sub-range before reading its bounds */
    if(low>high) continue;
    if(k<sm->colidx[low] || k>sm->colidx[high]) continue;

    /* binary search for k in the remaining part of row i */
    while(low<=high){
      mid=(low+high)>>1;
      diff=k-sm->colidx[mid];
      if(diff<0)
        high=mid-1;
      else if(diff>0)
        low=mid+1;
      else
        return 1;
    }
  }

  return 0;
}
#endif
00291
00292
#if 0
/*
 * Self-test driver (disabled).
 * Builds a CRS matrix from a small dense example and prints its raw arrays.
 * It then reconstructs the dense matrix through sba_crsm_elmidx() and lists
 * every row's and column's stored elements.
 */
int main(void)
{
  int mat[7][6]={
    {10, 0, 0, 0, -2, 0},
    {3, 9, 0, 0, 0, 3},
    {0, 7, 8, 7, 0, 0},
    {3, 0, 8, 7, 5, 0},
    {0, 8, 0, 9, 9, 13},
    {0, 4, 0, 0, 2, -1},
    {3, 7, 0, 9, 2, 0}
  };

  struct sba_crsm sm;
  int i, j, k, l;
  /* vidxs is reused for rows (<= 6 entries) and columns (<= 7 entries) */
  int vidxs[7],
      jidxs[6], iidxs[7];

  sba_crsm_build(&sm, mat[0], 7, 6);
  sba_crsm_print(&sm, stdout);

  for(i=0; i<7; ++i){
    for(j=0; j<6; ++j)
      printf("%3d ", ((k=sba_crsm_elmidx(&sm, i, j))!=-1)? sm.val[k] : 0);
    printf("\n");
  }

  for(i=0; i<7; ++i){
    k=sba_crsm_row_elmidxs(&sm, i, vidxs, jidxs);
    printf("row %d\n", i);
    for(l=0; l<k; ++l){
      j=jidxs[l];
      printf("%d %d ", j, sm.val[vidxs[l]]);
    }
    printf("\n");
  }

  for(j=0; j<6; ++j){
    k=sba_crsm_col_elmidxs(&sm, j, vidxs, iidxs);
    printf("col %d\n", j);
    for(l=0; l<k; ++l){
      i=iidxs[l];
      printf("%d %d ", i, sm.val[vidxs[l]]);
    }
    printf("\n");
  }

  sba_crsm_free(&sm);
  return 0;
}
#endif