Goto sanos source index
//
// smbproto.c
//
// SMB protocol
//
// Copyright (C) 2002 Michael Ringgaard. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// 3. Neither the name of the project nor the names of its contributors
// may be used to endorse or promote products derived from this software
// without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
// OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
// OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
// SUCH DAMAGE.
//
#include <os/krnl.h>
#include "smb.h"
// Head of the singly linked list of server connections, chained through
// smb_server.next. smb_get_connection() pushes new servers onto the front;
// smb_release_connection() unlinks a server when its last share is released.
struct smb_server *servers = NULL;
struct smb *smb_init(struct smb_share *share, int aux) {
struct smb *smb;
if (aux) {
smb = (struct smb *) share->server->auxbuf;
} else {
smb = (struct smb *) share->server->buffer;
}
memset(smb, 0, sizeof(struct smb));
return smb;
}
int smb_send(struct smb_share *share, struct smb *smb, unsigned char cmd, int params, char *data, int datasize) {
int len;
int rc;
char *p;
len = SMB_HEADER_LEN + params * 2 + 2 + datasize;
smb->type = SMB_SESSION_MESSAGE;
smb->len[0] = (len & 0xFF0000) >> 16;
smb->len[1] = (len & 0xFF00) >> 8;
smb->len[2] = (len & 0xFF);
smb->protocol[0] = 0xFF;
smb->protocol[1] = 'S';
smb->protocol[2] = 'M';
smb->protocol[3] = 'B';
smb->cmd = cmd;
smb->tid = share->tid;
smb->uid = share->server->uid;
smb->wordcount = (unsigned char) params;
smb->flags = (1 << 3);
smb->flags2 = 1;
p = (char *) smb->params.words + params * 2;
*((unsigned short *) p) = (unsigned short) datasize;
if (datasize) memcpy(p + 2, data, datasize);
rc = send(share->server->sock, (char *) smb, len + 4, 0);
if (rc < 0) return rc;
if (rc != len + 4) return -EIO;
return 0;
}
int smb_recv(struct smb_share *share, struct smb *smb) {
int len;
int rc;
while (1) {
rc = recv_fully(share->server->sock, (char *) smb, 4, 0);
if (rc < 0) return rc;
if (rc != 4) return -EIO;
if (smb->type == SMB_SESSION_MESSAGE) {
break;
} else if (smb->type != SMB_SESSION_KEEP_ALIVE) {
return -EIO;
}
}
len = smb->len[2] | (smb->len[1] << 8) | (smb->len[0] << 16);
if (len < 4 || len > SMB_MAX_BUFFER) return -EMSGSIZE;
rc = recv_fully(share->server->sock, (char *) &smb->protocol, len, 0);
if (rc < 0) return rc;
if (rc != len) return -EIO;
if (smb->protocol[0] != 0xFF || smb->protocol[1] != 'S' || smb->protocol[2] != 'M' || smb->protocol[3] != 'B') return -EPROTO;
if (smb->error_class != SMB_SUCCESS) return smb_errno(smb);
return 0;
}
//
// smb_request
//
// Sends one SMB command built in 'smb' and receives the reply into the same
// buffer.
//
// When 'retry' is non-zero, the connection (session and tree) is checked
// first. If the send then fails with -ECONN or -ERST, the connection is
// re-established once and the command is sent again.
//
// Returns 0 on success or the first negative error encountered.
//
int smb_request(struct smb_share *share, struct smb *smb, unsigned char cmd, int params, char *data, int datasize, int retry) {
  int rc;

  if (retry) {
    rc = smb_check_connection(share);
    if (rc < 0) return rc;
  }

  rc = smb_send(share, smb, cmd, params, data, datasize);
  if (retry && (rc == -ERST || rc == -ECONN)) {
    // Connection lost: reconnect and resend exactly once
    rc = smb_reconnect(share);
    if (rc < 0) return rc;
    rc = smb_send(share, smb, cmd, params, data, datasize);
  }
  if (rc < 0) return rc;

  rc = smb_recv(share, smb);
  return rc < 0 ? rc : 0;
}
int smb_trans_send(struct smb_share *share, unsigned short cmd,
void *params, int paramlen,
void *data, int datalen,
int maxparamlen, int maxdatalen) {
struct smb *smb;
int wordcount = 15;
int paramofs = ROUNDUP(SMB_HEADER_LEN + 2 * wordcount + 2 + 3);
int dataofs = ROUNDUP(paramofs + paramlen);
int bcc = dataofs + datalen - (SMB_HEADER_LEN + 2 * wordcount + 2);
int len = SMB_HEADER_LEN + 2 * wordcount + 2 + bcc;
char *p;
int rc;
rc = smb_check_connection(share);
if (rc < 0) return rc;
smb = (struct smb *) share->server->buffer;
memset(smb, 0, sizeof(struct smb));
smb->type = SMB_SESSION_MESSAGE;
smb->len[0] = (len > 0xFF0000) >> 16;
smb->len[1] = (len & 0xFF00) >> 8;
smb->len[2] = (len & 0xFF);
smb->protocol[0] = 0xFF;
smb->protocol[1] = 'S';
smb->protocol[2] = 'M';
smb->protocol[3] = 'B';
smb->cmd = SMB_COM_TRANSACTION2;
smb->tid = share->tid;
smb->uid = share->server->uid;
smb->wordcount = wordcount;
smb->flags = (1 << 3);
smb->flags2 = 1;
smb->params.req.trans.total_parameter_count = paramlen;
smb->params.req.trans.total_data_count = datalen;
smb->params.req.trans.max_parameter_count = maxparamlen;
smb->params.req.trans.max_data_count = maxdatalen;
smb->params.req.trans.max_setup_count = 0;
smb->params.req.trans.parameter_count = paramlen;
smb->params.req.trans.parameter_offset = paramofs;
smb->params.req.trans.data_count = datalen;
smb->params.req.trans.data_offset = dataofs;
smb->params.req.trans.setup_count = 1;
smb->params.req.trans.setup[0] = cmd;
p = (char *) smb->params.words + wordcount * 2;
*((unsigned short *) p) = (unsigned short) bcc;
p = (char *) smb + 4;
if (params) memcpy(p + paramofs, params, paramlen);
if (data) memcpy(p + dataofs, data, datalen);
rc = send(share->server->sock, (char *) smb, len + 4, 0);
if (rc == -ECONN || rc == -ERST) {
rc = smb_reconnect(share);
if (rc < 0) return rc;
rc = send(share->server->sock, (char *) smb, len + 4, 0);
if (rc < 0) return rc;
if (rc != len + 4) return -EIO;
} else {
if (rc < 0) return rc;
if (rc != len + 4) return -EIO;
}
return 0;
}
int smb_trans_recv(struct smb_share *share,
void *params, int *paramlen,
void *data, int *datalen) {
struct smb *smb;
int paramofs, paramdisp, paramcnt;
int dataofs, datadisp, datacnt;
int rc;
while (1) {
// Receive next block
smb = (struct smb *) share->server->buffer;
rc = smb_recv(share, smb);
if (rc < 0) return rc;
// Copy parameters
paramofs = smb->params.rsp.trans.parameter_offset;
paramdisp = smb->params.rsp.trans.parameter_displacement;
paramcnt = smb->params.rsp.trans.parameter_count;
if (params) {
if (paramdisp + paramcnt > *paramlen) return -EBUF;
if (paramcnt) memcpy((char *) params + paramdisp, (char *) smb + paramofs + 4, paramcnt);
}
// Copy data
dataofs = smb->params.rsp.trans.data_offset;
datadisp = smb->params.rsp.trans.data_displacement;
datacnt = smb->params.rsp.trans.data_count;
if (data) {
if (datadisp + datacnt > *datalen) return -EBUF;
if (datacnt) memcpy((char *) data + datadisp, (char *) smb + dataofs + 4, datacnt);
}
// Check for last block
if (paramdisp + paramcnt == smb->params.rsp.trans.total_parameter_count &&
datadisp + datacnt == smb->params.rsp.trans.total_data_count) {
*paramlen = smb->params.rsp.trans.total_parameter_count;
*datalen = smb->params.rsp.trans.total_data_count;
return 0;
}
}
}
int smb_trans(struct smb_share *share,
unsigned short cmd,
void *reqparams, int reqparamlen,
void *reqdata, int reqdatalen,
void *rspparams, int *rspparamlen,
void *rspdata, int *rspdatalen) {
int rc;
int dummyparamlen;
int dummydatalen;
if (!rspparamlen) {
dummyparamlen = 0;
rspparamlen = &dummyparamlen;
}
if (!rspdatalen) {
dummydatalen = 0;
rspdatalen = &dummydatalen;
}
rc = smb_trans_send(share, cmd, reqparams, reqparamlen, reqdata, reqdatalen, *rspparamlen, *rspdatalen);
if (rc < 0) return rc;
rc = smb_trans_recv(share, rspparams, rspparamlen, rspdata, rspdatalen);
if (rc == -ERST) {
rc = smb_reconnect(share);
if (rc < 0) return rc;
rc = smb_trans_send(share, cmd, reqparams, reqparamlen, reqdata, reqdatalen, *rspparamlen, *rspdatalen);
if (rc < 0) return rc;
rc = smb_trans_recv(share, rspparams, rspparamlen, rspdata, rspdatalen);
}
if (rc < 0) return rc;
return 0;
}
int smb_connect_tree(struct smb_share *share) {
struct smb *smb;
int rc;
char *p;
char buf[SMB_NAMELEN];
// Connect to share
smb = smb_init(share, 1);
smb->params.req.connect.andx.cmd = 0xFF;
smb->params.req.connect.password_length = strlen(share->server->password);
p = buf;
p = addstr(p, share->server->password);
p = addpathz(p, share->sharename);
p = addstrz(p, SMB_SERVICE_DISK);
rc = smb_request(share, smb, SMB_COM_TREE_CONNECT_ANDX, 4, buf, p - buf, 0);
if (rc < 0) return rc;
share->tid = smb->tid;
share->mounttime = time(0);
return 0;
}
//
// smb_disconnect_tree
//
// Best-effort TREE_DISCONNECT for the share, sent from the auxiliary
// buffer. The result of the request is ignored. The share's tree id is
// reset to 0xFFFF (no tree) either way, and 0 is always returned.
//
int smb_disconnect_tree(struct smb_share *share) {
  smb_request(share, smb_init(share, 1), SMB_COM_TREE_DISCONNECT, 0, NULL, 0, 0);
  share->tid = 0xFFFF;
  return 0;
}
int smb_connect(struct smb_share *share) {
struct smb_server *server = share->server;
struct smb *smb;
struct sockaddr_in sin;
int rc;
unsigned short max_mpx_count;
char *p;
char buf[SMB_NAMELEN];
// Connect to SMB server
rc = socket(AF_INET, SOCK_STREAM, IPPROTO_IP, &server->sock);
if (rc < 0) return rc;
memset(&sin, 0, sizeof(sin));
sin.sin_family = AF_INET;
sin.sin_addr.s_addr = server->ipaddr.addr;
sin.sin_port = htons(server->port);
rc = connect(server->sock, (struct sockaddr *) &sin, sizeof(sin));
if (rc < 0) goto error;
// Negotiate protocol version
smb = smb_init(share, 1);
rc = smb_request(share, smb, SMB_COM_NEGOTIATE, 0, "\002NT LM 0.12", 12, 0);
if (rc < 0) goto error;
if (smb->params.rsp.negotiate.dialect_index == 0xFFFF) {
rc = -EREMOTEIO;
goto error;
}
server->tzofs = smb->params.rsp.negotiate.server_timezone * 60;
server->server_caps = smb->params.rsp.negotiate.capabilities;
server->max_buffer_size = smb->params.rsp.negotiate.max_buffer_size;
max_mpx_count = smb->params.rsp.negotiate.max_mpx_count;
// Setup session
smb = smb_init(share, 1);
smb->params.req.setup.andx.cmd = 0xFF;
smb->params.req.setup.max_buffer_size = (unsigned short) SMB_MAX_BUFFER;
smb->params.req.setup.max_mpx_count = max_mpx_count;
smb->params.req.setup.ansi_password_length = strlen(server->password);
smb->params.req.setup.unicode_password_length = 0;
smb->params.req.setup.capabilities = SMB_CAP_NT_SMBS;
p = buf;
p = addstr(p, server->password);
p = addstrz(p, server->username);
p = addstrz(p, server->domain);
p = addstrz(p, SMB_CLIENT_OS);
p = addstrz(p, SMB_CLIENT_LANMAN);
rc = smb_request(share, smb, SMB_COM_SESSION_SETUP_ANDX, 13, buf, p - buf, 0);
if (rc < 0) goto error;
server->uid = smb->uid;
return 0;
error:
if (server->sock) {
closesocket(server->sock);
server->sock = NULL;
}
return rc;
}
//
// smb_disconnect
//
// Tears down the server connection if a socket is open:
// - if a session exists (uid != 0xFFFF), sends LOGOFF_ANDX from the
//   auxiliary buffer, ignores the result, and marks the session gone;
// - closes the socket and resets server->sock to NULL.
// Always returns 0.
//
int smb_disconnect(struct smb_share *share) {
  struct smb_server *server = share->server;
  struct smb *msg;

  if (!server->sock) return 0;

  if (server->uid != 0xFFFF) {
    msg = smb_init(share, 1);
    msg->params.andx.cmd = 0xFF;
    smb_request(share, msg, SMB_COM_LOGOFF_ANDX, 2, NULL, 0, 0);
    server->uid = 0xFFFF;
  }

  closesocket(server->sock);
  server->sock = NULL;
  return 0;
}
int smb_get_connection(struct smb_share *share, struct ip_addr *ipaddr, unsigned short port, char *domain, char *username, char *password) {
struct smb_server *server;
int rc;
// Try to find existing connection to server
server = servers;
while (server) {
if (ip_addr_cmp(&server->ipaddr, ipaddr) && server->port == port) {
// Add share to server
share->server = server;
share->next = server->shares;
server->shares = share;
server->refcnt++;
return 0;
}
server = server->next;
}
// Allocate new server block
server = (struct smb_server *) kmalloc(sizeof(struct smb_server));
if (!server) return -ENOMEM;
memset(server, 0, sizeof(struct smb_server));
server->ipaddr = *ipaddr;
server->port = port;
strcpy(server->domain, domain);
strcpy(server->username, username);
strcpy(server->password, password);
server->uid = 0xFFFF;
init_mutex(&server->lock, 0);
// Add share to server
share->tid = 0xFFFF;
share->server = server;
share->next = server->shares;
server->shares = share;
server->refcnt++;
// Connect to server
rc = smb_connect(share);
if (rc < 0) {
share->server = NULL;
kfree(server);
return rc;
}
// Add server to server list
server->next = servers;
servers = server;
return 0;
}
//
// smb_release_connection
//
// Detaches 'share' from its server and drops one reference. When other
// shares still hold references, only this share is unlinked from the
// server's share list. When the last reference goes away, the connection is
// closed via smb_disconnect(), the server is unlinked from the global
// 'servers' list and its block is freed. share->server is NULL afterwards
// in both cases. Always returns 0; a share with no server is a no-op.
//
int smb_release_connection(struct smb_share *share) {
  struct smb_server *server = share->server;
  struct smb_share **sp;
  struct smb_server **svp;

  if (server == NULL) return 0;

  if (--server->refcnt > 0) {
    // Unlink this share from the server's share list
    for (sp = &server->shares; *sp != NULL; sp = &(*sp)->next) {
      if (*sp == share) {
        *sp = share->next;
        break;
      }
    }
    share->server = NULL;
    return 0;
  }

  // Last reference: disconnect and unlink the server block
  smb_disconnect(share);
  for (svp = &servers; *svp != NULL; svp = &(*svp)->next) {
    if (*svp == server) {
      *svp = server->next;
      break;
    }
  }
  kfree(server);
  share->server = NULL;
  return 0;
}
//
// smb_check_connection
//
// Makes sure the share is usable:
// - opens a new server connection (socket plus session) when the server
//   has no socket;
// - reconnects the tree when the share's tree id is 0xFFFF.
//
// Returns 0 on success or the negative error from smb_connect() or
// smb_connect_tree().
//
int smb_check_connection(struct smb_share *share) {
  int rc = 0;

  if (!share->server->sock) {
    rc = smb_connect(share);
    if (rc < 0) return rc;
  }

  if (share->tid == 0xFFFF) {
    rc = smb_connect_tree(share);
  }

  return rc < 0 ? rc : 0;
}
//
// smb_reconnect
//
// Rebuilds the server connection after it was lost:
// 1. Invalidates the tree id of every share on the server and the session
//    uid.
// 2. Closes any existing socket.
// 3. Reconnects the server and then the tree for 'share'.
// Other shares on the server are left with tid 0xFFFF, which
// smb_check_connection() treats as "reconnect the tree".
//
// Returns 0 on success or a negative error from smb_connect() or
// smb_connect_tree().
//
int smb_reconnect(struct smb_share *share) {
  struct smb_server *server = share->server;
  struct smb_share *it;
  int rc;

  for (it = server->shares; it != NULL; it = it->next) it->tid = 0xFFFF;

  // Invalidating uid before smb_disconnect() means it only closes the
  // socket and does not attempt a LOGOFF over the broken connection
  server->uid = 0xFFFF;
  if (server->sock) smb_disconnect(share);

  rc = smb_connect(share);
  if (rc < 0) return rc;

  rc = smb_connect_tree(share);
  return rc < 0 ? rc : 0;
}